1use dyn_clone::DynClone;
40use itertools::Itertools;
41use nalgebra::Complex;
42use parking_lot::RwLock;
43use rayon::prelude::*;
44use std::{
45 collections::HashSet,
46 fmt::{Debug, Display},
47 ops::{Add, Mul},
48 sync::Arc,
49};
50use tracing::{debug, info};
51
52use crate::{
53 convert,
54 dataset::{Dataset, Event},
55 errors::RustitudeError,
56 Field,
57};
58
59#[derive(Clone)]
61pub struct Parameter<F: Field> {
62 pub amplitude: String,
64 pub name: String,
66 pub index: Option<usize>,
69 pub fixed_index: Option<usize>,
72 pub initial: F,
75 pub bounds: (F, F),
78}
79impl<F: Field> Parameter<F> {
80 pub fn new(amplitude: &str, name: &str, index: usize) -> Self {
86 Self {
87 amplitude: amplitude.to_string(),
88 name: name.to_string(),
89 index: Some(index),
90 fixed_index: None,
91 initial: F::one(),
92 bounds: (F::neg_infinity(), F::infinity()),
93 }
94 }
95
96 pub const fn is_free(&self) -> bool {
98 self.index.is_some()
99 }
100
101 pub const fn is_fixed(&self) -> bool {
103 self.index.is_none()
104 }
105}
106
107impl<F: Field> Debug for Parameter<F> {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 if self.index.is_none() {
110 write!(
111 f,
112 "Parameter(name={}, value={} (fixed), bounds=({}, {}), parent={})",
113 self.name, self.initial, self.bounds.0, self.bounds.1, self.amplitude
114 )
115 } else {
116 write!(
117 f,
118 "Parameter(name={}, value={}, bounds=({}, {}), parent={})",
119 self.name, self.initial, self.bounds.0, self.bounds.1, self.amplitude
120 )
121 }
122 }
123}
124impl<F: Field> Display for Parameter<F> {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "{}", self.name)
127 }
128}
129
130pub trait Node<F: Field>: Sync + Send + DynClone {
248 fn precalculate(&mut self, _dataset: &Dataset<F>) -> Result<(), RustitudeError> {
260 Ok(())
261 }
262
263 fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError>;
277
278 fn parameters(&self) -> Vec<String> {
285 vec![]
286 }
287
288 fn into_amplitude(self, name: &str) -> Amplitude<F>
290 where
291 Self: std::marker::Sized + 'static,
292 {
293 Amplitude::new(name, self)
294 }
295
296 fn named(self, name: &str) -> Amplitude<F>
299 where
300 Self: std::marker::Sized + 'static,
301 {
302 self.into_amplitude(name)
303 }
304
305 fn is_python_node(&self) -> bool {
310 false
311 }
312}
313
314dyn_clone::clone_trait_object!(<F> Node<F>);
315
316pub trait AmpLike<F: Field>: Send + Sync + Debug + Display + AsTree + DynClone {
322 fn walk(&self) -> Vec<Amplitude<F>>;
326 fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>>;
329 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>>;
333 fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
335 None
336 }
337 fn real(&self) -> Real<F>
339 where
340 Self: std::marker::Sized + 'static,
341 {
342 Real(dyn_clone::clone_box(self))
343 }
344 fn imag(&self) -> Imag<F>
346 where
347 Self: Sized + 'static,
348 {
349 Imag(dyn_clone::clone_box(self))
350 }
351
352 fn prod(als: &Vec<Box<dyn AmpLike<F>>>) -> Product<F>
354 where
355 Self: Sized + 'static,
356 {
357 Product(*dyn_clone::clone_box(als))
358 }
359
360 fn sum(als: &Vec<Box<dyn AmpLike<F>>>) -> Sum<F>
362 where
363 Self: Sized + 'static,
364 {
365 Sum(*dyn_clone::clone_box(als))
366 }
367}
368dyn_clone::clone_trait_object!(<F> AmpLike<F>);
369
/// Renders a value as a box-drawn tree for display.
pub trait AsTree {
    /// Renders the whole tree starting from an unindented root.
    fn get_tree(&self) -> String {
        let mut root_bits: Vec<bool> = Vec::new();
        self._get_tree(&mut root_bits)
    }
    /// Builds the indentation prefix: one column per ancestor, drawing a vertical bar where
    /// that ancestor still has later siblings (`true`) and blank space otherwise.
    fn _get_indent(&self, bits: Vec<bool>) -> String {
        bits.into_iter()
            .map(|continues| if continues { " ┃ " } else { " " })
            .collect::<String>()
    }
    /// Connector drawn before a child that has later siblings.
    fn _get_intermediate(&self) -> String {
        " ┣━".to_owned()
    }
    /// Connector drawn before the last child.
    fn _get_end(&self) -> String {
        " ┗━".to_owned()
    }
    /// Renders this node; `bits` holds the continuation flags of its ancestors.
    fn _get_tree(&self, bits: &mut Vec<bool>) -> String;
}
398
399#[derive(Clone)]
405pub struct Amplitude<F: Field> {
406 pub name: String,
408 pub node: Box<dyn Node<F>>,
412 pub active: bool,
414 pub parameters: Vec<String>,
416 pub cache_position: usize,
419 pub parameter_index_start: usize,
422}
423
424impl<F: Field> Debug for Amplitude<F> {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 write!(f, "{}", self.name)
427 }
428}
429impl<F: Field> Display for Amplitude<F> {
430 #[rustfmt::skip]
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 writeln!(f, "Amplitude")?;
433 writeln!(f, " Name: {}", self.name)?;
434 writeln!(f, " Active: {}", self.active)?;
435 writeln!(f, " Cache Position: {}", self.cache_position)?;
436 writeln!(f, " Index of First Parameter: {}", self.parameter_index_start)
437 }
438}
439impl<F: Field> AsTree for Amplitude<F> {
440 fn _get_tree(&self, _bits: &mut Vec<bool>) -> String {
441 let name = if self.active {
442 self.name.clone()
443 } else {
444 format!("/* {} */", self.name)
445 };
446 if self.parameters().len() > 7 {
447 format!(" {}({},...)\n", name, self.parameters()[0..7].join(", "))
448 } else {
449 format!(" {}({})\n", name, self.parameters().join(", "))
450 }
451 }
452}
453impl<F: Field> Amplitude<F> {
454 pub fn new(name: &str, node: impl Node<F> + 'static) -> Self {
456 info!("Created new amplitude named {name}");
457 let parameters = node.parameters();
458 Self {
459 name: name.to_string(),
460 node: Box::new(node),
461 parameters,
462 active: true,
463 cache_position: 0,
464 parameter_index_start: 0,
465 }
466 }
467 pub fn register(
473 &mut self,
474 cache_position: usize,
475 parameter_index_start: usize,
476 dataset: &Dataset<F>,
477 ) -> Result<(), RustitudeError> {
478 self.cache_position = cache_position;
479 self.parameter_index_start = parameter_index_start;
480 self.precalculate(dataset)
481 }
482}
483impl<F: Field> Node<F> for Amplitude<F> {
484 fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
485 self.node.precalculate(dataset)?;
486 debug!("Precalculated amplitude {}", self.name);
487 Ok(())
488 }
489 fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
490 let res = self.node.calculate(
491 ¶meters
492 [self.parameter_index_start..self.parameter_index_start + self.parameters.len()],
493 event,
494 );
495 debug!(
496 "{}({:?}, event #{}) = {}",
497 self.name,
498 ¶meters
499 [self.parameter_index_start..self.parameter_index_start + self.parameters.len()],
500 event.index,
501 res.as_ref()
502 .map(|c| c.to_string())
503 .unwrap_or_else(|e| e.to_string())
504 );
505 res
506 }
507 fn parameters(&self) -> Vec<String> {
508 self.node.parameters()
509 }
510}
511impl<F: Field> AmpLike<F> for Amplitude<F> {
512 fn walk(&self) -> Vec<Self> {
513 vec![self.clone()]
514 }
515
516 fn walk_mut(&mut self) -> Vec<&mut Self> {
517 vec![self]
518 }
519
520 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
521 let res = cache[self.cache_position];
522 debug!(
523 "Computing {} from cache: {:?}",
524 self.name,
525 res.as_ref().map(|c| c.to_string())
526 );
527 res
528 }
529}
530
531#[derive(Clone)]
533pub struct Real<F: Field>(Box<dyn AmpLike<F>>);
534impl<F: Field> Debug for Real<F> {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 write!(f, "Real [ {:?} ]", self.0)
537 }
538}
539impl<F: Field> Display for Real<F> {
540 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541 writeln!(f, "{}", self.get_tree())
542 }
543}
544impl<F: Field> AmpLike<F> for Real<F> {
545 fn walk(&self) -> Vec<Amplitude<F>> {
546 self.0.walk()
547 }
548
549 fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
550 self.0.walk_mut()
551 }
552
553 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
554 let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.re.into());
555 debug!(
556 "Computing {:?} from cache: {:?}",
557 self,
558 res.as_ref().map(|c| c.to_string())
559 );
560 res
561 }
562}
563impl<F: Field> AsTree for Real<F> {
564 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
565 let mut res = String::from("[ real ]\n");
566 res.push_str(&self._get_indent(bits.to_vec()));
567 res.push_str(&self._get_end());
568 bits.push(false);
569 res.push_str(&self.0._get_tree(&mut bits.clone()));
570 bits.pop();
571 res
572 }
573}
574
575#[derive(Clone)]
577pub struct Imag<F: Field>(Box<dyn AmpLike<F>>);
578impl<F: Field> Debug for Imag<F> {
579 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
580 write!(f, "Imag [ {:?} ]", self.0)
581 }
582}
583impl<F: Field> Display for Imag<F> {
584 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
585 writeln!(f, "{}", self.get_tree())
586 }
587}
588impl<F: Field> AmpLike<F> for Imag<F> {
589 fn walk(&self) -> Vec<Amplitude<F>> {
590 self.0.walk()
591 }
592
593 fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
594 self.0.walk_mut()
595 }
596
597 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
598 let res: Option<Complex<F>> = self.0.compute(cache).map(|r| r.im.into());
599 debug!(
600 "Computing {:?} from cache: {:?}",
601 self,
602 res.as_ref().map(|c| c.to_string())
603 );
604 res
605 }
606}
607impl<F: Field> AsTree for Imag<F> {
608 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
609 let mut res = String::from("[ imag ]\n");
610 res.push_str(&self._get_indent(bits.to_vec()));
611 res.push_str(&self._get_end());
612 bits.push(false);
613 res.push_str(&self.0._get_tree(&mut bits.clone()));
614 bits.pop();
615 res
616 }
617}
618
619#[derive(Clone)]
621pub struct Product<F: Field>(Vec<Box<dyn AmpLike<F>>>);
622impl<F: Field> Debug for Product<F> {
623 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
624 write!(f, "Product [ ")?;
625 for op in &self.0 {
626 write!(f, "{:?} ", op)?;
627 }
628 write!(f, "]")
629 }
630}
631impl<F: Field> Display for Product<F> {
632 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
633 writeln!(f, "{}", self.get_tree())
634 }
635}
636impl<F: Field> AsTree for Product<F> {
637 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
638 let mut res = String::from("[ * ]\n");
639 for (i, op) in self.0.iter().enumerate() {
640 res.push_str(&self._get_indent(bits.to_vec()));
641 if i == self.0.len() - 1 {
642 res.push_str(&self._get_end());
643 bits.push(false);
644 } else {
645 res.push_str(&self._get_intermediate());
646 bits.push(true);
647 }
648 res.push_str(&op._get_tree(&mut bits.clone()));
649 bits.pop();
650 }
651 res
652 }
653}
654impl<F: Field> AmpLike<F> for Product<F> {
655 fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
656 Some(self.0.clone())
657 }
658 fn walk(&self) -> Vec<Amplitude<F>> {
659 self.0.iter().flat_map(|op| op.walk()).collect()
660 }
661
662 fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
663 self.0.iter_mut().flat_map(|op| op.walk_mut()).collect()
664 }
665
666 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
667 let mut values = self.0.iter().filter_map(|op| op.compute(cache)).peekable();
668 let res: Option<Complex<F>> = if values.peek().is_none() {
669 Some(Complex::default())
670 } else {
671 Some(values.product())
672 };
673 debug!(
674 "Computing {:?} from cache: {:?}",
675 self,
676 res.as_ref().map(|c| c.to_string())
677 );
678 res
679 }
680}
681
682#[derive(Clone)]
684pub struct Sum<F: Field>(pub Vec<Box<dyn AmpLike<F>>>);
685impl<F: Field> Debug for Sum<F> {
686 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
687 write!(f, "Sum [ ")?;
688 for op in &self.0 {
689 write!(f, "{:?} ", op)?;
690 }
691 write!(f, "]")
692 }
693}
694impl<F: Field> Display for Sum<F> {
695 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
696 writeln!(f, "{}", self.get_tree())
697 }
698}
699impl<F: Field> AsTree for Sum<F> {
700 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
701 let mut res = String::from("[ + ]\n");
702 for (i, op) in self.0.iter().enumerate() {
703 res.push_str(&self._get_indent(bits.to_vec()));
704 if i == self.0.len() - 1 {
705 res.push_str(&self._get_end());
706 bits.push(false);
707 } else {
708 res.push_str(&self._get_intermediate());
709 bits.push(true);
710 }
711 res.push_str(&op._get_tree(&mut bits.clone()));
712 bits.pop();
713 }
714 res
715 }
716}
717impl<F: Field> AmpLike<F> for Sum<F> {
718 fn get_cloned_terms(&self) -> Option<Vec<Box<dyn AmpLike<F>>>> {
719 Some(self.0.clone())
720 }
721 fn walk(&self) -> Vec<Amplitude<F>> {
722 self.0.iter().flat_map(|op| op.walk()).collect()
723 }
724
725 fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
726 self.0.iter_mut().flat_map(|op| op.walk_mut()).collect()
727 }
728
729 fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<Complex<F>> {
730 let res = Some(
731 self.0
732 .iter()
733 .filter_map(|al| al.compute(cache))
734 .sum::<Complex<F>>(),
735 );
736 debug!(
737 "Computing {:?} from cache: {:?}",
738 self,
739 res.as_ref().map(|c| c.to_string())
740 );
741 res
742 }
743}
744
745#[derive(Clone)]
747pub struct NormSqr<F: Field>(pub Box<dyn AmpLike<F>>);
748
749impl<F: Field> Debug for NormSqr<F> {
750 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
751 write!(f, "NormSqr[ {:?} ]", self.0)
752 }
753}
754impl<F: Field> Display for NormSqr<F> {
755 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
756 writeln!(f, "{}", self.get_tree())
757 }
758}
759impl<F: Field> AsTree for NormSqr<F> {
760 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
761 let mut res = String::from("[ |_|^2 ]\n");
762 res.push_str(&self._get_indent(bits.to_vec()));
763 res.push_str(&self._get_end());
764 bits.push(false);
765 res.push_str(&self.0._get_tree(&mut bits.clone()));
766 bits.pop();
767 res
768 }
769}
770impl<F: Field> NormSqr<F> {
771 pub fn compute(&self, cache: &[Option<Complex<F>>]) -> Option<F> {
777 self.0.compute(cache).map(|res| res.norm_sqr())
778 }
779
780 pub fn walk(&self) -> Vec<Amplitude<F>> {
782 self.0.walk()
783 }
784
785 pub fn walk_mut(&mut self) -> Vec<&mut Amplitude<F>> {
788 self.0.walk_mut()
789 }
790}
791
792#[derive(Clone)]
796pub struct Model<F: Field> {
797 pub cohsums: Vec<NormSqr<F>>,
799 pub amplitudes: Arc<RwLock<Vec<Amplitude<F>>>>,
801 pub parameters: Vec<Parameter<F>>,
803 pub contains_python_amplitudes: bool,
806}
807impl<F: Field> Debug for Model<F> {
808 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
809 write!(f, "Model [ ")?;
810 for op in &self.cohsums {
811 write!(f, "{:?} ", op)?;
812 }
813 write!(f, "]")
814 }
815}
816impl<F: Field> Display for Model<F> {
817 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
818 writeln!(f, "{}", self.get_tree())
819 }
820}
821impl<F: Field> AsTree for Model<F> {
822 fn _get_tree(&self, bits: &mut Vec<bool>) -> String {
823 let mut res = String::from("[ + ]\n");
824 for (i, op) in self.cohsums.iter().enumerate() {
825 res.push_str(&self._get_indent(bits.to_vec()));
826 if i == self.cohsums.len() - 1 {
827 res.push_str(&self._get_end());
828 bits.push(false);
829 } else {
830 res.push_str(&self._get_intermediate());
831 bits.push(true);
832 }
833 res.push_str(&op._get_tree(&mut bits.clone()));
834 bits.pop();
835 }
836 res
837 }
838}
839impl<F: Field> Model<F> {
840 pub fn new(amps: &[Box<dyn AmpLike<F>>]) -> Self {
842 let mut amp_names = HashSet::new();
843 let amplitudes: Vec<Amplitude<F>> = amps
844 .iter()
845 .flat_map(|cohsum| cohsum.walk())
846 .filter_map(|amp| {
847 if amp_names.insert(amp.name.clone()) {
848 Some(amp)
849 } else {
850 None
851 }
852 })
853 .collect();
854 let parameter_tags: Vec<(String, String)> = amplitudes
855 .iter()
856 .flat_map(|amp| {
857 amp.parameters()
858 .iter()
859 .map(|p| (amp.name.clone(), p.clone()))
860 .collect::<Vec<_>>()
861 })
862 .collect();
863 let parameters = parameter_tags
864 .iter()
865 .enumerate()
866 .map(|(i, (amp_name, par_name))| Parameter::new(amp_name, par_name, i))
867 .collect();
868 let contains_python_amplitudes = amplitudes.iter().any(|amp| amp.node.is_python_node());
869 Self {
870 cohsums: amps.iter().map(|inner| NormSqr(inner.clone())).collect(),
871 amplitudes: Arc::new(RwLock::new(amplitudes)),
872 parameters,
873 contains_python_amplitudes,
874 }
875 }
876 pub fn deep_clone(&self) -> Self {
879 Self {
880 cohsums: self.cohsums.clone(),
881 amplitudes: Arc::new(RwLock::new(self.amplitudes.read().clone())),
882 parameters: self.parameters.clone(),
883 contains_python_amplitudes: self.contains_python_amplitudes,
884 }
885 }
886 pub fn compute(
894 &self,
895 amplitudes: &[Amplitude<F>],
896 parameters: &[F],
897 event: &Event<F>,
898 ) -> Result<F, RustitudeError> {
899 let cache: Vec<Option<Complex<F>>> = amplitudes
905 .iter()
906 .map(|amp| {
907 if amp.active {
908 amp.calculate(parameters, event).map(Some)
909 } else {
910 Ok(None)
911 }
912 })
913 .collect::<Result<Vec<Option<Complex<F>>>, RustitudeError>>()?;
914 Ok(self
915 .cohsums
916 .iter()
917 .filter_map(|cohsum| cohsum.compute(&cache))
918 .sum::<F>())
919 }
920 pub fn load(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
927 let mut next_cache_pos = 0;
928 let mut parameter_index = 0;
929 self.amplitudes.write().iter_mut().try_for_each(|amp| {
930 amp.register(next_cache_pos, parameter_index, dataset)?;
931 self.cohsums.iter_mut().for_each(|cohsum| {
932 cohsum.walk_mut().iter_mut().for_each(|r_amp| {
933 if r_amp.name == amp.name {
934 r_amp.cache_position = next_cache_pos;
935 r_amp.parameter_index_start = parameter_index;
936 }
937 })
938 });
939 next_cache_pos += 1;
940 parameter_index += amp.parameters().len();
941 Ok(())
942 })
943 }
944
945 pub fn get_amplitude(&self, amplitude_name: &str) -> Result<Amplitude<F>, RustitudeError> {
950 self.amplitudes
951 .read()
952 .iter()
953 .find(|a: &&Amplitude<F>| a.name == amplitude_name)
954 .ok_or_else(|| RustitudeError::AmplitudeNotFoundError(amplitude_name.to_string()))
955 .cloned()
956 }
957 pub fn get_parameter(
963 &self,
964 amplitude_name: &str,
965 parameter_name: &str,
966 ) -> Result<Parameter<F>, RustitudeError> {
967 self.get_amplitude(amplitude_name)?;
968 self.parameters
969 .iter()
970 .find(|p: &&Parameter<F>| p.amplitude == amplitude_name && p.name == parameter_name)
971 .ok_or_else(|| RustitudeError::ParameterNotFoundError(parameter_name.to_string()))
972 .cloned()
973 }
974 pub fn print_parameters(&self) {
976 let any_fixed = if self.any_fixed() { 1 } else { 0 };
977 if self.any_fixed() {
978 println!(
979 "Fixed: {}",
980 self.group_by_index()[0]
981 .iter()
982 .map(|p| format!("{:?}", p))
983 .join(", ")
984 );
985 }
986 for (i, group) in self.group_by_index().iter().skip(any_fixed).enumerate() {
987 println!(
988 "{}: {}",
989 i,
990 group.iter().map(|p| format!("{:?}", p)).join(", ")
991 );
992 }
993 }
994
995 pub fn free_parameters(&self) -> Vec<Parameter<F>> {
997 self.parameters
998 .iter()
999 .filter(|p| p.is_free())
1000 .cloned()
1001 .collect()
1002 }
1003
1004 pub fn fixed_parameters(&self) -> Vec<Parameter<F>> {
1006 self.parameters
1007 .iter()
1008 .filter(|p| p.is_fixed())
1009 .cloned()
1010 .collect()
1011 }
1012
1013 pub fn constrain(
1020 &mut self,
1021 amplitude_1: &str,
1022 parameter_1: &str,
1023 amplitude_2: &str,
1024 parameter_2: &str,
1025 ) -> Result<(), RustitudeError> {
1026 let p1 = self.get_parameter(amplitude_1, parameter_1)?;
1027 let p2 = self.get_parameter(amplitude_2, parameter_2)?;
1028 for par in self.parameters.iter_mut() {
1029 match p1.index.cmp(&p2.index) {
1031 std::cmp::Ordering::Less => {
1033 if par.index == p2.index {
1034 par.index = p1.index;
1035 par.initial = p1.initial;
1036 par.fixed_index = p1.fixed_index;
1037 }
1038 }
1039 std::cmp::Ordering::Equal => unimplemented!(),
1040 std::cmp::Ordering::Greater => {
1042 if par.index == p1.index {
1043 par.index = p2.index;
1044 par.initial = p2.initial;
1045 par.fixed_index = p2.fixed_index;
1046 }
1047 }
1048 }
1049 }
1050 self.reindex_parameters();
1051 Ok(())
1052 }
1053
1054 pub fn fix(
1063 &mut self,
1064 amplitude: &str,
1065 parameter: &str,
1066 value: F,
1067 ) -> Result<(), RustitudeError> {
1068 let search_par = self.get_parameter(amplitude, parameter)?;
1069 let fixed_index = self.get_min_fixed_index();
1070 for par in self.parameters.iter_mut() {
1071 if par.index == search_par.index {
1072 par.index = None;
1073 par.initial = value;
1074 par.fixed_index = fixed_index;
1075 }
1076 }
1077 self.reindex_parameters();
1078 Ok(())
1079 }
1080 pub fn free(&mut self, amplitude: &str, parameter: &str) -> Result<(), RustitudeError> {
1089 let search_par = self.get_parameter(amplitude, parameter)?;
1090 let index = self.get_min_free_index();
1091 for par in self.parameters.iter_mut() {
1092 if par.fixed_index == search_par.fixed_index {
1093 par.index = index;
1094 par.fixed_index = None;
1095 }
1096 }
1097 self.reindex_parameters();
1098 Ok(())
1099 }
1100 pub fn set_bounds(
1106 &mut self,
1107 amplitude: &str,
1108 parameter: &str,
1109 bounds: (F, F),
1110 ) -> Result<(), RustitudeError> {
1111 let search_par = self.get_parameter(amplitude, parameter)?;
1112 if search_par.index.is_some() {
1113 for par in self.parameters.iter_mut() {
1114 if par.index == search_par.index {
1115 par.bounds = bounds;
1116 }
1117 }
1118 } else {
1119 for par in self.parameters.iter_mut() {
1120 if par.fixed_index == search_par.fixed_index {
1121 par.bounds = bounds;
1122 }
1123 }
1124 }
1125 Ok(())
1126 }
1127 pub fn set_initial(
1133 &mut self,
1134 amplitude: &str,
1135 parameter: &str,
1136 initial: F,
1137 ) -> Result<(), RustitudeError> {
1138 let search_par = self.get_parameter(amplitude, parameter)?;
1139 if search_par.index.is_some() {
1140 for par in self.parameters.iter_mut() {
1141 if par.index == search_par.index {
1142 par.initial = initial;
1143 }
1144 }
1145 } else {
1146 for par in self.parameters.iter_mut() {
1147 if par.fixed_index == search_par.fixed_index {
1148 par.initial = initial;
1149 }
1150 }
1151 }
1152 Ok(())
1153 }
1154 pub fn get_bounds(&self) -> Vec<(F, F)> {
1156 let any_fixed = if self.any_fixed() { 1 } else { 0 };
1157 self.group_by_index()
1158 .iter()
1159 .skip(any_fixed)
1160 .filter_map(|group| group.first().map(|par| par.bounds))
1161 .collect()
1162 }
1163 pub fn get_initial(&self) -> Vec<F> {
1165 let any_fixed = if self.any_fixed() { 1 } else { 0 };
1166 self.group_by_index()
1167 .iter()
1168 .skip(any_fixed)
1169 .filter_map(|group| group.first().map(|par| par.initial))
1170 .collect()
1171 }
1172 pub fn get_n_free(&self) -> usize {
1174 self.get_min_free_index().unwrap_or(0)
1175 }
1176 pub fn activate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
1183 if !self.amplitudes.read().iter().any(|a| a.name == amplitude) {
1184 return Err(RustitudeError::AmplitudeNotFoundError(
1185 amplitude.to_string(),
1186 ));
1187 }
1188 self.amplitudes.write().iter_mut().for_each(|amp| {
1189 if amp.name == amplitude {
1190 amp.active = true
1191 }
1192 });
1193 self.cohsums.iter_mut().for_each(|cohsum| {
1194 cohsum.walk_mut().iter_mut().for_each(|amp| {
1195 if amp.name == amplitude {
1196 amp.active = true
1197 }
1198 })
1199 });
1200 Ok(())
1201 }
1202 pub fn activate_all(&mut self) {
1204 self.amplitudes
1205 .write()
1206 .iter_mut()
1207 .for_each(|amp| amp.active = true);
1208 self.cohsums.iter_mut().for_each(|cohsum| {
1209 cohsum
1210 .walk_mut()
1211 .iter_mut()
1212 .for_each(|amp| amp.active = true)
1213 });
1214 }
1215 pub fn isolate(&mut self, amplitudes: Vec<&str>) -> Result<(), RustitudeError> {
1222 self.deactivate_all();
1223 for amplitude in amplitudes {
1224 self.activate(amplitude)?;
1225 }
1226 Ok(())
1227 }
1228 pub fn deactivate(&mut self, amplitude: &str) -> Result<(), RustitudeError> {
1235 if !self.amplitudes.read().iter().any(|a| a.name == amplitude) {
1236 return Err(RustitudeError::AmplitudeNotFoundError(
1237 amplitude.to_string(),
1238 ));
1239 }
1240 self.amplitudes.write().iter_mut().for_each(|amp| {
1241 if amp.name == amplitude {
1242 amp.active = false
1243 }
1244 });
1245 self.cohsums.iter_mut().for_each(|cohsum| {
1246 cohsum.walk_mut().iter_mut().for_each(|amp| {
1247 if amp.name == amplitude {
1248 amp.active = false
1249 }
1250 })
1251 });
1252 Ok(())
1253 }
1254 pub fn deactivate_all(&mut self) {
1256 self.amplitudes
1257 .write()
1258 .iter_mut()
1259 .for_each(|amp| amp.active = false);
1260 self.cohsums.iter_mut().for_each(|cohsum| {
1261 cohsum
1262 .walk_mut()
1263 .iter_mut()
1264 .for_each(|amp| amp.active = false)
1265 });
1266 }
1267 fn group_by_index(&self) -> Vec<Vec<&Parameter<F>>> {
1268 self.parameters
1269 .iter()
1270 .sorted_by_key(|par| par.index)
1271 .chunk_by(|par| par.index)
1272 .into_iter()
1273 .map(|(_, group)| group.collect::<Vec<_>>())
1274 .collect()
1275 }
1276 fn group_by_index_mut(&mut self) -> Vec<Vec<&mut Parameter<F>>> {
1277 self.parameters
1278 .iter_mut()
1279 .sorted_by_key(|par| par.index)
1280 .chunk_by(|par| par.index)
1281 .into_iter()
1282 .map(|(_, group)| group.collect())
1283 .collect()
1284 }
1285 fn any_fixed(&self) -> bool {
1286 self.parameters.iter().any(|p| p.index.is_none())
1287 }
1288 fn reindex_parameters(&mut self) {
1289 let any_fixed = if self.any_fixed() { 1 } else { 0 };
1290 self.group_by_index_mut()
1291 .iter_mut()
1292 .skip(any_fixed) .enumerate()
1294 .for_each(|(ind, par_group)| par_group.iter_mut().for_each(|par| par.index = Some(ind)))
1295 }
1296 fn get_min_free_index(&self) -> Option<usize> {
1297 self.parameters
1298 .iter()
1299 .filter_map(|p| p.index)
1300 .max()
1301 .map_or(Some(0), |max| Some(max + 1))
1302 }
1303 fn get_min_fixed_index(&self) -> Option<usize> {
1304 self.parameters
1305 .iter()
1306 .filter_map(|p| p.fixed_index)
1307 .max()
1308 .map_or(Some(0), |max| Some(max + 1))
1309 }
1310}
1311
1312#[derive(Clone)]
1320pub struct Scalar;
1321impl<F: Field> Node<F> for Scalar {
1322 fn parameters(&self) -> Vec<String> {
1323 vec!["value".to_string()]
1324 }
1325 fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1326 Ok(Complex::new(parameters[0], F::zero()))
1327 }
1328}
1329
1330pub fn scalar<F: Field>(name: &str) -> Amplitude<F> {
1345 Amplitude::new(name, Scalar)
1346}
1347#[derive(Clone)]
1357pub struct ComplexScalar;
1358impl<F: Field> Node<F> for ComplexScalar {
1359 fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1360 Ok(Complex::new(parameters[0], parameters[1]))
1361 }
1362
1363 fn parameters(&self) -> Vec<String> {
1364 vec!["real".to_string(), "imag".to_string()]
1365 }
1366}
1367pub fn cscalar<F: Field>(name: &str) -> Amplitude<F> {
1382 Amplitude::new(name, ComplexScalar)
1383}
1384
1385#[derive(Clone)]
1395pub struct PolarComplexScalar;
1396impl<F: Field> Node<F> for PolarComplexScalar {
1397 fn calculate(&self, parameters: &[F], _event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1398 Ok(Complex::cis(parameters[1]).mul(parameters[0]))
1399 }
1400
1401 fn parameters(&self) -> Vec<String> {
1402 vec!["mag".to_string(), "phi".to_string()]
1403 }
1404}
1405
1406pub fn pcscalar<F: Field>(name: &str) -> Amplitude<F> {
1421 Amplitude::new(name, PolarComplexScalar)
1422}
1423
1424#[derive(Clone)]
1426pub struct Piecewise<V, F>
1427where
1428 V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1429 F: Field,
1430{
1431 edges: Vec<(F, F)>,
1432 variable: V,
1433 calculated_variable: Vec<F>,
1434}
1435
1436impl<V, F> Piecewise<V, F>
1437where
1438 V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1439 F: Field,
1440{
1441 pub fn new(bins: usize, range: (F, F), variable: V) -> Self {
1444 let diff = (range.1 - range.0) / convert!(bins, F);
1445 let edges = (0..bins)
1446 .map(|i| {
1447 (
1448 F::mul_add(convert!(i, F), diff, range.0),
1449 F::mul_add(convert!(i + 1, F), diff, range.0),
1450 )
1451 })
1452 .collect();
1453 Self {
1454 edges,
1455 variable,
1456 calculated_variable: Vec::default(),
1457 }
1458 }
1459}
1460
1461impl<V, F> Node<F> for Piecewise<V, F>
1462where
1463 V: Fn(&Event<F>) -> F + Send + Sync + Copy,
1464 F: Field,
1465{
1466 fn precalculate(&mut self, dataset: &Dataset<F>) -> Result<(), RustitudeError> {
1467 self.calculated_variable = dataset.events.par_iter().map(self.variable).collect();
1468 Ok(())
1469 }
1470
1471 fn calculate(&self, parameters: &[F], event: &Event<F>) -> Result<Complex<F>, RustitudeError> {
1472 let val = self.calculated_variable[event.index];
1473 let opt_i_bin = self.edges.iter().position(|&(l, r)| val >= l && val <= r);
1474 opt_i_bin.map_or_else(
1475 || Ok(Complex::default()),
1476 |i_bin| {
1477 Ok(Complex::new(
1478 parameters[i_bin * 2],
1479 parameters[(i_bin * 2) + 1],
1480 ))
1481 },
1482 )
1483 }
1484
1485 fn parameters(&self) -> Vec<String> {
1486 (0..self.edges.len())
1487 .flat_map(|i| vec![format!("bin {} re", i), format!("bin {} im", i)])
1488 .collect()
1489 }
1490}
1491
1492pub fn piecewise_m<F: Field + 'static>(name: &str, bins: usize, range: (F, F)) -> Amplitude<F> {
1493 Amplitude::new(
1495 name,
1496 Piecewise::new(bins, range, |e: &Event<F>| {
1497 (e.daughter_p4s[0] + e.daughter_p4s[1]).m()
1498 }),
1499 )
1500}
1501
// Generates `Add` impls that build a two-term `Sum`.
//
// Arms:
// - `($t, $a, $b)`: `$a + $b` and `$b + $a`, for every owned/borrowed combination.
// - `($t, $a)`: `$a + $a`, for every owned/borrowed combination.
//
// Borrowed operands are cloned, then the owned impl is called. Terms keep left-to-right
// order.
macro_rules! impl_sum {
    ($t:ident, $a:ty, $b:ty) => {
        // Owned `$a + $b`: the core impl every other `$a + $b` form delegates to.
        impl<$t: Field + 'static> Add<$b> for $a {
            type Output = Sum<$t>;

            fn add(self, rhs: $b) -> Self::Output {
                Sum(vec![Box::new(self), Box::new(rhs)])
            }
        }

        // `&$a + &$b`: clone both sides.
        impl<$t: Field + 'static> Add<&$b> for &$a {
            type Output = <$a as Add<$b>>::Output;

            fn add(self, rhs: &$b) -> Self::Output {
                <$a as Add<$b>>::add(self.clone(), rhs.clone())
            }
        }

        // `$a + &$b`: clone the right side.
        impl<$t: Field + 'static> Add<&$b> for $a {
            type Output = <$a as Add<$b>>::Output;

            fn add(self, rhs: &$b) -> Self::Output {
                <$a as Add<$b>>::add(self, rhs.clone())
            }
        }

        // `&$a + $b`: clone the left side.
        impl<$t: Field + 'static> Add<$b> for &$a {
            type Output = <$a as Add<$b>>::Output;

            fn add(self, rhs: $b) -> Self::Output {
                <$a as Add<$b>>::add(self.clone(), rhs)
            }
        }

        // Owned `$b + $a`: the mirrored core impl.
        impl<$t: Field + 'static> Add<$a> for $b {
            type Output = Sum<$t>;

            fn add(self, rhs: $a) -> Self::Output {
                Sum(vec![Box::new(self), Box::new(rhs)])
            }
        }

        // `&$b + &$a`: clone both sides.
        impl<$t: Field + 'static> Add<&$a> for &$b {
            type Output = <$b as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <$b as Add<$a>>::add(self.clone(), rhs.clone())
            }
        }

        // `$b + &$a`: clone the right side.
        impl<$t: Field + 'static> Add<&$a> for $b {
            type Output = <$b as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <$b as Add<$a>>::add(self, rhs.clone())
            }
        }

        // `&$b + $a`: clone the left side.
        impl<$t: Field + 'static> Add<$a> for &$b {
            type Output = <$b as Add<$a>>::Output;

            fn add(self, rhs: $a) -> Self::Output {
                <$b as Add<$a>>::add(self.clone(), rhs)
            }
        }
    };
    ($t:ident, $a:ty) => {
        // Owned `$a + $a`: the core impl for same-type addition.
        impl<$t: Field + 'static> Add<$a> for $a {
            type Output = Sum<$t>;

            fn add(self, rhs: $a) -> Self::Output {
                Sum(vec![Box::new(self), Box::new(rhs)])
            }
        }

        // `&$a + &$a`: clone both sides.
        impl<$t: Field + 'static> Add<&$a> for &$a {
            type Output = <$a as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <$a as Add<$a>>::add(self.clone(), rhs.clone())
            }
        }

        // `$a + &$a`: clone the right side.
        impl<$t: Field + 'static> Add<&$a> for $a {
            type Output = <$a as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <$a as Add<$a>>::add(self, rhs.clone())
            }
        }

        // `&$a + $a`: clone the left side.
        impl<$t: Field + 'static> Add<$a> for &$a {
            type Output = <$a as Add<$a>>::Output;

            fn add(self, rhs: $a) -> Self::Output {
                <$a as Add<$a>>::add(self.clone(), rhs)
            }
        }
    };
}
// Generates `Add` impls between `$a` and an existing `Sum<$t>` that extend the sum's term
// list instead of nesting one `Sum` inside another:
// - `$a + Sum` inserts `$a` as the first term.
// - `Sum + $a` appends `$a` as the last term.
// All owned/borrowed combinations are covered; borrowed operands are cloned.
macro_rules! impl_appending_sum {
    ($t:ident, $a:ty) => {
        // Owned `$a + Sum`: prepend.
        impl<$t: Field + 'static> Add<Sum<$t>> for $a {
            type Output = Sum<$t>;

            fn add(self, rhs: Sum<$t>) -> Self::Output {
                let mut terms = rhs.0;
                terms.insert(0, Box::new(self));
                Sum(terms)
            }
        }

        // Owned `Sum + $a`: append.
        impl<$t: Field + 'static> Add<$a> for Sum<$t> {
            type Output = Sum<$t>;

            fn add(self, rhs: $a) -> Self::Output {
                let mut terms = self.0;
                terms.push(Box::new(rhs));
                Sum(terms)
            }
        }

        // `&$a + &Sum`: clone both sides, then prepend.
        impl<$t: Field + 'static> Add<&Sum<$t>> for &$a {
            type Output = <$a as Add<Sum<$t>>>::Output;

            fn add(self, rhs: &Sum<$t>) -> Self::Output {
                <$a as Add<Sum<$t>>>::add(self.clone(), rhs.clone())
            }
        }

        // `$a + &Sum`: clone the sum, then prepend.
        impl<$t: Field + 'static> Add<&Sum<$t>> for $a {
            type Output = <$a as Add<Sum<$t>>>::Output;

            fn add(self, rhs: &Sum<$t>) -> Self::Output {
                <$a as Add<Sum<$t>>>::add(self, rhs.clone())
            }
        }

        // `&$a + Sum`: clone the left side, then prepend.
        impl<$t: Field + 'static> Add<Sum<$t>> for &$a {
            type Output = <$a as Add<Sum<$t>>>::Output;

            fn add(self, rhs: Sum<$t>) -> Self::Output {
                <$a as Add<Sum<$t>>>::add(self.clone(), rhs)
            }
        }

        // `&Sum + &$a`: clone both sides, then append.
        impl<$t: Field + 'static> Add<&$a> for &Sum<$t> {
            type Output = <Sum<$t> as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <Sum<$t> as Add<$a>>::add(self.clone(), rhs.clone())
            }
        }

        // `Sum + &$a`: clone the right side, then append.
        impl<$t: Field + 'static> Add<&$a> for Sum<$t> {
            type Output = <Sum<$t> as Add<$a>>::Output;

            fn add(self, rhs: &$a) -> Self::Output {
                <Sum<$t> as Add<$a>>::add(self, rhs.clone())
            }
        }

        // `&Sum + $a`: clone the sum, then append.
        impl<$t: Field + 'static> Add<$a> for &Sum<$t> {
            type Output = <Sum<$t> as Add<$a>>::Output;

            fn add(self, rhs: $a) -> Self::Output {
                <Sum<$t> as Add<$a>>::add(self.clone(), rhs)
            }
        }
    };
}
// Generates `Mul` impls that combine amplitude-like operands into one flat
// `Product`.
//
// * `impl_prod!(F, A, B)` implements `A * B` and `B * A`.
// * `impl_prod!(F, A)` implements `A * A`.
//
// Each case covers all four owned/borrowed operand combinations. The borrowed
// variants clone their operands and defer to the owned impl. An operand whose
// `get_cloned_terms()` returns `Some(factors)` is spliced in factor-by-factor
// rather than nested. The left operand's factors always precede the right's.
macro_rules! impl_prod {
    ($t:ident, $a:ty, $b:ty) => {
        impl<$t: Field + 'static> Mul<$b> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: $b) -> Self::Output {
                let factors: Vec<Box<dyn AmpLike<$t>>> =
                    match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                        (Some(mut left), Some(right)) => {
                            left.extend(right);
                            left
                        }
                        (None, Some(mut right)) => {
                            right.insert(0, Box::new(self));
                            right
                        }
                        (Some(mut left), None) => {
                            left.push(Box::new(rhs));
                            left
                        }
                        (None, None) => vec![Box::new(self), Box::new(rhs)],
                    };
                Product(factors)
            }
        }

        impl<$t: Field + 'static> Mul<&$b> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$b) -> Self::Output {
                self.clone() * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<&$b> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$b) -> Self::Output {
                self * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<$b> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: $b) -> Self::Output {
                self.clone() * rhs
            }
        }

        impl<$t: Field + 'static> Mul<$a> for $b {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                let factors: Vec<Box<dyn AmpLike<$t>>> =
                    match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                        (Some(mut left), Some(right)) => {
                            left.extend(right);
                            left
                        }
                        (None, Some(mut right)) => {
                            right.insert(0, Box::new(self));
                            right
                        }
                        (Some(mut left), None) => {
                            left.push(Box::new(rhs));
                            left
                        }
                        (None, None) => vec![Box::new(self), Box::new(rhs)],
                    };
                Product(factors)
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &$b {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self.clone() * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for $b {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &$b {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                self.clone() * rhs
            }
        }
    };
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<$a> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                let factors: Vec<Box<dyn AmpLike<$t>>> =
                    match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                        (Some(mut left), Some(right)) => {
                            left.extend(right);
                            left
                        }
                        (None, Some(mut right)) => {
                            right.insert(0, Box::new(self));
                            right
                        }
                        (Some(mut left), None) => {
                            left.push(Box::new(rhs));
                            left
                        }
                        (None, None) => vec![Box::new(self), Box::new(rhs)],
                    };
                Product(factors)
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self.clone() * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for $a {
            type Output = Product<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &$a {
            type Output = Product<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                self.clone() * rhs
            }
        }
    };
}
// Generates `Mul` impls between a concrete amplitude-like type and a boxed
// `dyn AmpLike` trait object, in both operand orders.
//
// An operand whose `get_cloned_terms()` returns `Some(factors)` is spliced
// into the resulting `Product` factor-by-factor. Any other operand becomes a
// single factor. The left operand's factors always come first.
//
// NOTE(review): flattening is keyed only on `get_cloned_terms()` being
// `Some`. If that also holds for a boxed `Sum`, its terms would be spliced in
// here as factors. Confirm against the `AmpLike` impls.
macro_rules! impl_box_prod {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<Box<dyn AmpLike<$t>>> for $a {
            type Output = Product<$t>;
            fn mul(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Product([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` already holds `self`'s factors; only the
                        // right-hand operand still has to be appended.
                        terms.push(rhs);
                        Product(terms)
                    }
                    (None, None) => Product(vec![Box::new(self), rhs]),
                }
            }
        }
        impl<$t: Field + 'static> Mul<$a> for Box<dyn AmpLike<$t>> {
            type Output = Product<$t>;
            fn mul(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Product([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        terms.insert(0, self);
                        Product(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` already holds the boxed left operand's
                        // factors; append the concrete right-hand operand.
                        terms.push(Box::new(rhs));
                        Product(terms)
                    }
                    (None, None) => Product(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
// Generates `Add` impls between a concrete amplitude-like type and a boxed
// `dyn AmpLike` trait object, in both operand orders.
//
// An operand whose `get_cloned_terms()` returns `Some(terms)` is spliced into
// the resulting `Sum` term-by-term. Any other operand becomes a single term.
// The left operand's terms always come first.
//
// NOTE(review): the product macros flatten with the same `get_cloned_terms`
// hook. If it returns `Some` for a `Product`, that product's factors would be
// spliced in here as separate summands. Confirm against the `AmpLike` impls.
macro_rules! impl_box_sum {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Add<Box<dyn AmpLike<$t>>> for $a {
            type Output = Sum<$t>;
            fn add(self, rhs: Box<dyn AmpLike<$t>>) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Sum([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        terms.insert(0, Box::new(self));
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` already holds `self`'s terms; only the
                        // right-hand operand still has to be appended.
                        terms.push(rhs);
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![Box::new(self), rhs]),
                }
            }
        }
        impl<$t: Field + 'static> Add<$a> for Box<dyn AmpLike<$t>> {
            type Output = Sum<$t>;
            fn add(self, rhs: $a) -> Self::Output {
                match (self.get_cloned_terms(), rhs.get_cloned_terms()) {
                    (Some(terms_a), Some(terms_b)) => Sum([terms_a, terms_b].concat()),
                    (None, Some(mut terms)) => {
                        terms.insert(0, self);
                        Sum(terms)
                    }
                    (Some(mut terms), None) => {
                        // `terms` already holds the boxed left operand's
                        // terms; append the concrete right-hand operand.
                        terms.push(Box::new(rhs));
                        Sum(terms)
                    }
                    (None, None) => Sum(vec![self, Box::new(rhs)]),
                }
            }
        }
    };
}
// Generates `Mul` impls between an amplitude-like type and a `Sum` that
// distribute the multiplication over the sum's terms:
// `x * (s1 + s2 + ...)` becomes `x*s1 + x*s2 + ...`, and `sum * x` becomes
// `s1*x + s2*x + ...`, with operand order preserved inside each product.
// Each term-wise product is computed via `$a * Box<dyn AmpLike>` (or the
// mirrored form). Borrowed operands are cloned and forwarded to the owned
// impls.
macro_rules! impl_dist {
    ($t:ident, $a:ty) => {
        impl<$t: Field + 'static> Mul<Sum<$t>> for $a {
            type Output = Sum<$t>;

            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                Sum(rhs
                    .0
                    .into_iter()
                    .map(|summand| Box::new(self.clone() * summand) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }

        impl<$t: Field + 'static> Mul<$a> for Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                Sum(self
                    .0
                    .into_iter()
                    .map(|summand| Box::new(summand * rhs.clone()) as Box<dyn AmpLike<$t>>)
                    .collect())
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for &Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self.clone() * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<&$a> for Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: &$a) -> Self::Output {
                self * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<$a> for &Sum<$t> {
            type Output = Sum<$t>;

            fn mul(self, rhs: $a) -> Self::Output {
                self.clone() * rhs
            }
        }

        impl<$t: Field + 'static> Mul<&Sum<$t>> for &$a {
            type Output = Sum<$t>;

            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                self.clone() * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<&Sum<$t>> for $a {
            type Output = Sum<$t>;

            fn mul(self, rhs: &Sum<$t>) -> Self::Output {
                self * rhs.clone()
            }
        }

        impl<$t: Field + 'static> Mul<Sum<$t>> for &$a {
            type Output = Sum<$t>;

            fn mul(self, rhs: Sum<$t>) -> Self::Output {
                self.clone() * rhs
            }
        }
    };
}
1971
1972impl_sum!(F, Amplitude<F>);
1973impl_box_sum!(F, Amplitude<F>);
1974impl_sum!(F, Real<F>);
1975impl_box_sum!(F, Real<F>);
1976impl_sum!(F, Imag<F>);
1977impl_box_sum!(F, Imag<F>);
1978impl_sum!(F, Product<F>);
1979impl_box_sum!(F, Product<F>);
1980impl_box_sum!(F, Sum<F>);
1981
1982impl_sum!(F, Amplitude<F>, Real<F>);
1983impl_sum!(F, Amplitude<F>, Imag<F>);
1984impl_sum!(F, Amplitude<F>, Product<F>);
1985impl_sum!(F, Real<F>, Imag<F>);
1986impl_sum!(F, Real<F>, Product<F>);
1987impl_sum!(F, Imag<F>, Product<F>);
1988
1989impl_appending_sum!(F, Amplitude<F>);
1990impl_appending_sum!(F, Real<F>);
1991impl_appending_sum!(F, Imag<F>);
1992impl_appending_sum!(F, Product<F>);
1993
1994impl_prod!(F, Amplitude<F>);
1995impl_box_prod!(F, Amplitude<F>);
1996impl_prod!(F, Real<F>);
1997impl_box_prod!(F, Real<F>);
1998impl_prod!(F, Imag<F>);
1999impl_box_prod!(F, Imag<F>);
2000impl_prod!(F, Product<F>);
2001impl_box_prod!(F, Product<F>);
2002
2003impl_prod!(F, Amplitude<F>, Real<F>);
2004impl_prod!(F, Amplitude<F>, Imag<F>);
2005impl_prod!(F, Amplitude<F>, Product<F>);
2006impl_prod!(F, Real<F>, Imag<F>);
2007impl_prod!(F, Real<F>, Product<F>);
2008impl_prod!(F, Imag<F>, Product<F>);
2009
2010impl_dist!(F, Amplitude<F>);
2011impl_dist!(F, Real<F>);
2012impl_dist!(F, Imag<F>);
2013impl_dist!(F, Product<F>);
2014
2015impl<F: Field> Add<Self> for Sum<F> {
2016 type Output = Self;
2017
2018 fn add(self, rhs: Self) -> Self::Output {
2019 Self([self.0, rhs.0].concat())
2020 }
2021}
2022
2023impl<F: Field> Add<&Sum<F>> for &Sum<F> {
2024 type Output = <Sum<F> as Add<Sum<F>>>::Output;
2025
2026 fn add(self, rhs: &Sum<F>) -> Self::Output {
2027 <Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs.clone())
2028 }
2029}
2030
2031impl<F: Field> Add<&Self> for Sum<F> {
2032 type Output = <Self as Add<Self>>::Output;
2033
2034 fn add(self, rhs: &Self) -> Self::Output {
2035 <Self as Add<Self>>::add(self, rhs.clone())
2036 }
2037}
2038
2039impl<F: Field> Add<Sum<F>> for &Sum<F> {
2040 type Output = <Sum<F> as Add<Sum<F>>>::Output;
2041
2042 fn add(self, rhs: Sum<F>) -> Self::Output {
2043 <Sum<F> as Add<Sum<F>>>::add(self.clone(), rhs)
2044 }
2045}