rustnomial/polynomial/find_roots.rs

1use alloc::vec::Vec;
2use core::ops::{Add, AddAssign, Div, DivAssign, Mul, Neg, Sub, SubAssign};
3
4use num::{Complex, One, Zero};
5use roots::find_roots_sturm;
6
7use crate::numerics::{AbsSqrt, Cbrt, IsPositive};
8use crate::polynomial::polynomial::{first_nonzero_index, first_term};
9use crate::{Degree, Evaluable, SizedPolynomial, Term};
10
11#[derive(Clone, Debug, PartialEq)]
12pub enum Roots<N> {
13    NoRoots,
14    NoRootsFound,
15    OneRealRoot(N),
16    TwoRealRoots(N, N),
17    ThreeRealRoots(N, N, N),
18    ManyRealRoots(Vec<N>),
19    OneComplexRoot(Complex<N>),
20    TwoComplexRoots(Complex<N>, Complex<N>),
21    ThreeComplexRoots(Complex<N>, Complex<N>, Complex<N>),
22    ManyComplexRoots(Vec<Complex<N>>),
23    InfiniteRoots,
24    OnlyRealRoots(Vec<f64>),
25}
26
/// Returns the discriminant `b² − 4ac` of the quadratic `a·x² + b·x + c`.
///
/// Works for any copyable numeric type that can be built from a small integer literal.
#[inline(always)]
pub(crate) fn discriminant_trinomial<N>(a: N, b: N, c: N) -> N
where
    N: Copy + Mul<Output = N> + Sub<Output = N> + From<u8>,
{
    let four = N::from(4);
    let b_squared = b * b;
    let four_ac = a * c * four;
    b_squared - four_ac
}
34
35pub(crate) fn trinomial_roots<N>(a: N, b: N, c: N) -> Roots<N>
36where
37    N: Copy
38        + Mul<Output = N>
39        + Div<Output = N>
40        + Sub<Output = N>
41        + Add<Output = N>
42        + AbsSqrt
43        + IsPositive
44        + Zero
45        + Neg<Output = N>
46        + From<u8>,
47{
48    let discriminant = discriminant_trinomial(a, b, c);
49    let a = a * N::from(2);
50    let b = -b / a;
51
52    if discriminant.is_zero() {
53        return Roots::TwoRealRoots(b, b);
54    }
55
56    let sqrt = discriminant.abs_sqrt() / a;
57    if discriminant.is_positive() {
58        Roots::TwoRealRoots(b + sqrt, b - sqrt)
59    } else {
60        Roots::TwoComplexRoots(Complex::new(b, sqrt), Complex::new(b, -sqrt))
61    }
62}
63
64#[allow(clippy::many_single_char_names)]
65pub(crate) fn cubic_roots<N>(a: N, b: N, c: N, d: N) -> Roots<N>
66where
67    N: Copy
68        + Mul<Output = N>
69        + Div<Output = N>
70        + Sub<Output = N>
71        + Add<Output = N>
72        + AbsSqrt
73        + Cbrt
74        + IsPositive
75        + Zero
76        + One
77        + Neg<Output = N>
78        + From<u8>,
79{
80    let sqr = |x: N| x * x;
81    let cub = |x: N| x * x * x;
82    let p = -b / (N::from(3) * a);
83    let q = cub(p) + (b * c - N::from(3) * a * d) / (N::from(6) * sqr(a));
84    let r = c / (N::from(3) * a);
85    let k = (sqr(q) + cub(r - sqr(p))).abs_sqrt();
86    let x = (q + k).cbrt() + (q - k).cbrt() + p;
87
88    let b = b / a + x;
89    let c = c / a + b * x;
90    let roots = trinomial_roots(N::one(), b, c);
91    match roots {
92        Roots::TwoRealRoots(a, b) => Roots::ThreeRealRoots(x, a, b),
93        Roots::TwoComplexRoots(a, b) => Roots::ThreeComplexRoots(Complex::new(x, N::zero()), a, b),
94        _ => unreachable!(),
95    }
96}
97
98// pub fn complex_roots_quartic<N>(a: N, b: N, c: N, d: N, e: N) -> (Complex<N>, Complex<N>, Complex<N>, Complex<N>)
99// where
100//     N: Copy
101//         + Mul<Output = N>
102//         + Div<Output = N>
103//         + Sub<Output = N>
104//         + Add<Output = N>
105//         + AbsSqrt
106//         + Cbrt
107//         + IsPositive
108//         + Zero
109//         + One
110//         + Neg<Output = N>
111//         + From<u8>
112//         + PartialOrd
113// {
114//     let sqr = |x: N| x * x;
115//     let cub = |x: N| x * x * x;
116// }
117
118fn div<
119    N: Zero + Copy + Neg<Output = N> + AddAssign + SubAssign + Mul<Output = N> + Div<Output = N> + One,
120>(
121    values: &mut [N],
122    root: N,
123) -> Vec<N> {
124    let zero = N::zero();
125    let rhs_first = N::one();
126
127    let (mut coeff, mut self_degree) = match first_term(&values) {
128        Term::ZeroTerm => return vec![],
129        Term::Term(_, 1) => return vec![],
130        Term::Term(coeff, degree) => (coeff, degree),
131    };
132
133    let mut div = vec![zero; self_degree];
134    let offset = self_degree;
135
136    while self_degree >= 1 {
137        let scale = coeff / rhs_first;
138        let loc = values.len() - self_degree - 1;
139        values[loc] -= rhs_first * scale;
140        values[loc + 1] += root * scale;
141        div[offset - self_degree] = scale;
142        match first_term(&values) {
143            Term::ZeroTerm => break,
144            Term::Term(coeffx, degree) => {
145                coeff = coeffx;
146                self_degree = degree;
147            }
148        }
149    }
150    div
151}
152
153fn normalize<N: Zero + Copy + DivAssign>(values: &mut [N]) {
154    let f_i = first_nonzero_index(values);
155    if f_i == values.len() {
156        return;
157    }
158    let first = values[f_i];
159    for val in values[f_i..].iter_mut() {
160        *val /= first;
161    }
162}
163
164/// Finds roots for special cases (eg. cubic polynomials and below, and monomials).
165pub(crate) fn find_roots_special(poly: &[(f64, usize)]) -> Option<Roots<f64>> {
166    Some(match poly {
167        [] => Roots::InfiniteRoots,
168        [(_, 0)] => Roots::NoRoots,
169        [_] => Roots::ManyRealRoots(vec![0.]),
170        [(c1, 1), (c2, 0)] => Roots::ManyRealRoots(vec![-*c2 / *c1]),
171        [(a, 2), one_or_more @ ..] => {
172            let (b, c) = match one_or_more {
173                [] => (0., 0.),
174                [(xc, 0)] => (0., *xc),
175                [(xb, 1)] => (*xb, 0.),
176                [(xb, 1), (xc, 0)] => (*xb, *xc),
177                _ => unreachable!(),
178            };
179            match trinomial_roots(*a, b, c) {
180                Roots::TwoComplexRoots(a, b) => Roots::ManyComplexRoots(vec![a, b]),
181                Roots::TwoRealRoots(a, b) => Roots::ManyRealRoots(vec![a, b]),
182                _ => unreachable!(),
183            }
184        }
185        [(a, 3), one_or_more @ ..] => {
186            let (b, c, d) = match one_or_more {
187                [] => (0., 0., 0.),
188                [(xd, 0)] => (0., 0., *xd),
189                [(xc, 1)] => (0., *xc, 0.),
190                [(xc, 1), (xd, 0)] => (0., *xc, *xd),
191                [(xb, 2)] => (*xb, 0., 0.),
192                [(xb, 2), (xd, 0)] => (*xb, 0., *xd),
193                [(xb, 2), (xc, 1)] => (*xb, *xc, 0.),
194                [(xb, 2), (xc, 1), (xd, 0)] => (*xb, *xc, *xd),
195                _ => unreachable!(),
196            };
197            match cubic_roots(*a, b, c, d) {
198                Roots::ThreeComplexRoots(a, b, c) => Roots::ManyComplexRoots(vec![a, b, c]),
199                Roots::ThreeRealRoots(a, b, c) => Roots::ManyRealRoots(vec![a, b, c]),
200                _ => unreachable!(),
201            }
202        }
203        _ => return None,
204    })
205}
206
207/// Finds the roots of the polynomial with terms defined by the given vector, where each element
208/// is a tuple consisting of the coefficient and degree. Order is not guaranteed.
209pub(crate) fn find_roots<S: SizedPolynomial<f64> + Evaluable<f64>>(poly: &S) -> Roots<f64> {
210    let vals = poly.terms_as_vec();
211
212    if let Some(roots) = find_roots_special(&vals) {
213        return roots;
214    }
215
216    // NOTE: According to
217    // https://en.wikipedia.org/wiki/Geometrical_properties_of_polynomial_roots
218    // the largest root can be no larger than the largest coefficient divided by the
219    // coefficient of the degree 0 term (assuming it isn't zero - but in that case,
220    // we can just divide the polynomial by x).
221    let (leading, degree) = vals[0];
222    let mut values = vec![0f64; degree + 1];
223    let mut nvalues = vec![0f64; degree + 1];
224
225    nvalues[0] = leading;
226    for (val, val_deg) in vals[1..].iter() {
227        values[degree - val_deg] = *val / leading;
228        nvalues[degree - val_deg] = *val;
229    }
230
231    let mut roots = vec![];
232    loop {
233        let temp_roots: Vec<f64> = find_roots_sturm(&values[1..], &mut 1e-8f64)
234            .into_iter()
235            .filter_map(Result::ok)
236            .collect();
237
238        if temp_roots.is_empty() {
239            match poly.degree() {
240                Degree::Num(x) => {
241                    if x == temp_roots.len() {
242                        return Roots::ManyRealRoots(roots);
243                    }
244                }
245                _ => unreachable!("Polynomial should not be zero in this stage."),
246            }
247            return if roots.is_empty() {
248                Roots::NoRoots
249            } else {
250                Roots::OnlyRealRoots(roots)
251            };
252        }
253
254        for root in temp_roots {
255            let root = {
256                let x = root.round();
257                if poly.eval(x).abs() < poly.eval(root).abs() {
258                    x
259                } else {
260                    root
261                }
262            };
263            roots.push(root);
264            nvalues = div(&mut nvalues, root);
265        }
266
267        if nvalues.is_empty() {
268            return if roots.is_empty() {
269                Roots::NoRoots
270            } else {
271                Roots::ManyRealRoots(roots)
272            };
273        }
274        normalize(&mut nvalues);
275        let leading = nvalues[0];
276        values = nvalues
277            .iter()
278            .map(|&val| val / leading)
279            .collect::<Vec<f64>>();
280    }
281}
282
#[cfg(test)]
mod test {
    use crate::polynomial::find_roots::{cubic_roots, find_roots};
    use crate::{LinearBinomial, Monomial, Polynomial, Roots, SizedPolynomial};

    #[test]
    fn test_roots_empty() {
        let zero_poly = Polynomial::<f64>::zero();
        let expected = Roots::InfiniteRoots;
        assert_eq!(expected, find_roots(&zero_poly));
    }

    #[test]
    fn test_roots_constant() {
        let constant = Monomial::new(1., 0);
        let expected = Roots::NoRoots;
        assert_eq!(expected, find_roots(&constant));
    }

    #[test]
    fn test_roots_binomial() {
        let linear = LinearBinomial::new([1., 2.]);
        let expected = Roots::ManyRealRoots(vec![-2.]);
        assert_eq!(expected, find_roots(&linear));
    }

    #[test]
    fn test_roots_cubic_a_equals_one() {
        let expected = Roots::ThreeRealRoots(-2.0, -2.0, -2.0);
        assert_eq!(expected, cubic_roots(1f64, 6., 12., 8.));
    }

    #[test]
    fn test_roots_cubic_a_does_not_equal_one() {
        let expected = Roots::ThreeRealRoots(-2.0, -2.0, -2.0);
        assert_eq!(expected, cubic_roots(2f64, 12., 24., 16.));
    }

    #[test]
    fn test_cubic_polynomials() {
        // (x + 2)^3
        let cubic = Polynomial::new(vec![1f64, 6., 12., 8.]);
        let expected = Roots::ManyRealRoots(vec![-2., -2., -2.]);
        assert_eq!(expected, find_roots(&cubic));
    }

    #[test]
    fn test_large_polynomials() {
        // (x + 2)^9 · (x + 3)
        let big = Polynomial::new(vec![1f64, 2.]).pow(9) * Polynomial::new(vec![1f64, 3.]);
        let expected =
            Roots::ManyRealRoots(vec![-3., -2., -2., -2., -2., -2., -2., -2., -2., -2.]);
        assert_eq!(expected, find_roots(&big));
    }

    #[test]
    fn test_quad_no_real_roots() {
        // x^4 + x^3 + x^2 + x + 1 has only non-real roots.
        let quartic = Polynomial::<f64>::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let expected = Roots::NoRoots;
        assert_eq!(expected, find_roots(&quartic));
    }

    // #[test]
    // fn test_large_polynomials_fractional() {
    //     let p = Polynomial::new(vec![1f64, 2./3.]).pow(6) * Polynomial::new(vec![1f64, 3.]);
    //     assert_eq!(Roots::ManyRealRoots(vec![-3., 2./3., 2./3., 2./3., 2./3., 2./3., 2./3.]), find_roots(&p));
    // }

    // #[test]
    // fn test_roots_quartic_a_equals_one() {
    //     let c = Complex::new(-2.0, 0.);
    //     assert_eq!((c, c, c, c), complex_roots_quartic(1f32, 8., 24., 32., 16.));
    // }
    //
    // #[test]
    // fn test_roots_quartic_a_does_not_equal_one() {
    //     let c = Complex::new(-2.0, 0.);
    //     assert_eq!((c, c, c, c), complex_roots_quartic(2f32, 16., 48., 64., 32.));
    // }
}