1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use crate::generator::{error_non_generator, Axis, Coefs, GenOp, Generator};
use crate::operations::helpers::handle_id_error;
use crate::{GetLengthRatio, NormalForm, Term};
use num_integer::lcm;
use num_rational::Rational64;
use rand::{rngs::StdRng, SeedableRng};
use scop::Defs;
use weresocool_error::Error;

impl GetLengthRatio<Term> for GenOp {
    /// Computes the summed length ratio of the elements this generator op
    /// produces, using its default element count.
    ///
    /// This is exactly `self.get_length_ratio_genop(None, ..)`:
    /// * a `Named` op is resolved through `defs` and re-seeded;
    /// * a `Const` op renders `Generator::lcm_length()` elements;
    /// * a `Taken` op renders its own explicit count.
    ///
    /// # Errors
    ///
    /// Propagates any error from resolving a named op. This includes the
    /// non-generator error when the name is bound to some other kind of term.
    /// It also propagates errors from generating an element or computing its
    /// length ratio.
    fn get_length_ratio(
        &self,
        normal_form: &NormalForm,
        defs: &mut Defs<Term>,
    ) -> Result<Rational64, Error> {
        // Delegate instead of duplicating the per-variant logic, so the two
        // entry points cannot drift apart.
        self.get_length_ratio_genop(None, normal_form, defs)
    }
}

impl GenOp {
    /// Returns the number of elements this generator op produces by default.
    ///
    /// * `Named`: resolves `name` in `defs`, applies this op's `seed` to the
    ///   resolved op, and returns that op's length.
    /// * `Const`: the wrapped generator's `lcm_length()`.
    /// * `Taken`: the explicit count `n`.
    ///
    /// # Errors
    ///
    /// For `Named`, propagates the error from `handle_id_error`. It returns the
    /// non-generator error if the name is bound to a term that is not a
    /// generator.
    pub fn length(&self, defs: &mut Defs<Term>) -> Result<usize, Error> {
        match self {
            GenOp::Named { name, seed } => {
                let generator = handle_id_error(name, defs)?;
                match generator {
                    Term::Gen(mut gen) => {
                        gen.set_seed(*seed);
                        gen.length(defs)
                    }
                    _ => Err(error_non_generator()),
                }
            }
            GenOp::Const { gen, .. } => Ok(gen.lcm_length()),
            GenOp::Taken { n, .. } => Ok(*n),
        }
    }

    /// Computes the summed length ratio of the elements this generator op
    /// produces.
    ///
    /// How `n` is used depends on the variant:
    /// * `Named`: `n` is forwarded unchanged to the resolved op.
    /// * `Const`: `n` overrides the element count, which defaults to the
    ///   generator's `lcm_length()`.
    /// * `Taken`: `n` is ignored; the op always uses its own explicit count.
    ///
    /// # Errors
    ///
    /// For `Named`, propagates the error from `handle_id_error`. It returns the
    /// non-generator error if the name is bound to a non-generator term. All
    /// variants propagate errors from generating elements or computing their
    /// length ratios.
    pub fn get_length_ratio_genop(
        &self,
        n: Option<usize>,
        normal_form: &NormalForm,
        defs: &mut Defs<Term>,
    ) -> Result<Rational64, Error> {
        match self {
            GenOp::Named { name, seed } => {
                let generator = handle_id_error(name, defs)?;
                match generator {
                    Term::Gen(mut gen) => {
                        gen.set_seed(*seed);
                        gen.get_length_ratio_genop(n, normal_form, defs)
                    }
                    _ => Err(error_non_generator()),
                }
            }
            GenOp::Const { gen, seed } => {
                // lcm_length() is only evaluated when the caller gave no count.
                let count = n.unwrap_or_else(|| gen.lcm_length());
                gen.get_length(count, *seed, normal_form, defs)
            }
            GenOp::Taken { n, gen, seed } => {
                // This `n` (the op's own count) shadows the parameter on
                // purpose: an explicit take overrides any count from the caller.
                let mut gen = gen.to_owned();
                gen.set_seed(*seed);
                gen.get_length_ratio_genop(Some(*n), normal_form, defs)
            }
        }
    }
}

impl Generator {
    /// Sums the length ratios of the first `n` steps this generator yields when
    /// its random number generator is seeded from `seed`.
    ///
    /// Generation runs on a clone of `self`, so this generator's own
    /// coefficients are not modified. On each of the `n` steps, every
    /// coefficient generates exactly one value.
    ///
    /// The step's length is the product of the length ratios of the values
    /// generated by `Axis::L` coefficients, or 1 if there are none. The result
    /// is the sum of all step lengths, and is 0 when `n == 0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from generating a coefficient value or from
    /// computing a generated value's length ratio.
    pub fn get_length(
        &self,
        n: usize,
        seed: u64,
        normal_form: &NormalForm,
        defs: &mut Defs<Term>,
    ) -> Result<Rational64, Error> {
        let mut rng: StdRng = SeedableRng::seed_from_u64(seed);
        let mut copy = self.clone();
        let mut total = Rational64::from_integer(0);

        for _ in 0..n {
            let mut step_length = Rational64::from_integer(1);
            for coef in copy.coefs.iter_mut() {
                // Non-length coefficients are still generated, so that the
                // shared rng (and any per-coefficient state) advances the same
                // way on every step regardless of axis.
                match coef.axis {
                    Axis::L => {
                        let ratio = coef
                            .generate(&mut rng)?
                            .get_length_ratio(normal_form, defs)?;
                        step_length *= ratio;
                    }
                    _ => {
                        coef.generate(&mut rng)?;
                    }
                }
            }
            total += step_length;
        }

        Ok(total)
    }

    /// Returns one more than the least common multiple of the coefficients'
    /// cycle lengths.
    ///
    /// A `Const` coefficient contributes the number of its values. `Poly` and
    /// `Expr` coefficients contribute `div - 1`. With no coefficients the
    /// result is 2.
    ///
    /// NOTE(review): several edge cases are not handled here:
    /// * `num_integer::lcm` returns 0 when either argument is 0, so any
    ///   zero-length coefficient (an empty `Const`, or `div == 1`) collapses
    ///   the result to 1.
    /// * `div == 0` would underflow `div - 1`.
    ///
    /// Confirm that both cases are impossible upstream. The reason for the
    /// leading `1 +` is not visible from this function.
    pub fn lcm_length(&self) -> usize {
        let cycle_lcm = self
            .coefs
            .iter()
            .map(|coef| match &coef.coefs {
                Coefs::Const(values) => values.len(),
                Coefs::Poly(_) | Coefs::Expr { .. } => coef.div - 1,
            })
            .fold(1usize, |acc, len| lcm(acc, len));
        1 + cycle_lcm
    }
}