1use alloc::vec::Vec;
2use core::ops::{Add, AddAssign, Div, DivAssign, Mul, Neg, Sub, SubAssign};
3
4use num::{Complex, One, Zero};
5use roots::find_roots_sturm;
6
7use crate::numerics::{AbsSqrt, Cbrt, IsPositive};
8use crate::polynomial::polynomial::{first_nonzero_index, first_term};
9use crate::{Degree, Evaluable, SizedPolynomial, Term};
10
11#[derive(Clone, Debug, PartialEq)]
12pub enum Roots<N> {
13 NoRoots,
14 NoRootsFound,
15 OneRealRoot(N),
16 TwoRealRoots(N, N),
17 ThreeRealRoots(N, N, N),
18 ManyRealRoots(Vec<N>),
19 OneComplexRoot(Complex<N>),
20 TwoComplexRoots(Complex<N>, Complex<N>),
21 ThreeComplexRoots(Complex<N>, Complex<N>, Complex<N>),
22 ManyComplexRoots(Vec<Complex<N>>),
23 InfiniteRoots,
24 OnlyRealRoots(Vec<f64>),
25}
26
/// Returns the discriminant `b² − 4ac` of the quadratic `a·x² + b·x + c`.
///
/// Its sign selects between a double root (zero), two real roots (positive)
/// and a complex-conjugate pair (negative).
#[inline(always)]
pub(crate) fn discriminant_trinomial<N>(a: N, b: N, c: N) -> N
where
    N: Copy + Mul<Output = N> + Sub<Output = N> + From<u8>,
{
    let four = N::from(4);
    let b_squared = b * b;
    let four_ac = a * c * four;
    b_squared - four_ac
}
34
35pub(crate) fn trinomial_roots<N>(a: N, b: N, c: N) -> Roots<N>
36where
37 N: Copy
38 + Mul<Output = N>
39 + Div<Output = N>
40 + Sub<Output = N>
41 + Add<Output = N>
42 + AbsSqrt
43 + IsPositive
44 + Zero
45 + Neg<Output = N>
46 + From<u8>,
47{
48 let discriminant = discriminant_trinomial(a, b, c);
49 let a = a * N::from(2);
50 let b = -b / a;
51
52 if discriminant.is_zero() {
53 return Roots::TwoRealRoots(b, b);
54 }
55
56 let sqrt = discriminant.abs_sqrt() / a;
57 if discriminant.is_positive() {
58 Roots::TwoRealRoots(b + sqrt, b - sqrt)
59 } else {
60 Roots::TwoComplexRoots(Complex::new(b, sqrt), Complex::new(b, -sqrt))
61 }
62}
63
64#[allow(clippy::many_single_char_names)]
65pub(crate) fn cubic_roots<N>(a: N, b: N, c: N, d: N) -> Roots<N>
66where
67 N: Copy
68 + Mul<Output = N>
69 + Div<Output = N>
70 + Sub<Output = N>
71 + Add<Output = N>
72 + AbsSqrt
73 + Cbrt
74 + IsPositive
75 + Zero
76 + One
77 + Neg<Output = N>
78 + From<u8>,
79{
80 let sqr = |x: N| x * x;
81 let cub = |x: N| x * x * x;
82 let p = -b / (N::from(3) * a);
83 let q = cub(p) + (b * c - N::from(3) * a * d) / (N::from(6) * sqr(a));
84 let r = c / (N::from(3) * a);
85 let k = (sqr(q) + cub(r - sqr(p))).abs_sqrt();
86 let x = (q + k).cbrt() + (q - k).cbrt() + p;
87
88 let b = b / a + x;
89 let c = c / a + b * x;
90 let roots = trinomial_roots(N::one(), b, c);
91 match roots {
92 Roots::TwoRealRoots(a, b) => Roots::ThreeRealRoots(x, a, b),
93 Roots::TwoComplexRoots(a, b) => Roots::ThreeComplexRoots(Complex::new(x, N::zero()), a, b),
94 _ => unreachable!(),
95 }
96}
97
98fn div<
119 N: Zero + Copy + Neg<Output = N> + AddAssign + SubAssign + Mul<Output = N> + Div<Output = N> + One,
120>(
121 values: &mut [N],
122 root: N,
123) -> Vec<N> {
124 let zero = N::zero();
125 let rhs_first = N::one();
126
127 let (mut coeff, mut self_degree) = match first_term(&values) {
128 Term::ZeroTerm => return vec![],
129 Term::Term(_, 1) => return vec![],
130 Term::Term(coeff, degree) => (coeff, degree),
131 };
132
133 let mut div = vec![zero; self_degree];
134 let offset = self_degree;
135
136 while self_degree >= 1 {
137 let scale = coeff / rhs_first;
138 let loc = values.len() - self_degree - 1;
139 values[loc] -= rhs_first * scale;
140 values[loc + 1] += root * scale;
141 div[offset - self_degree] = scale;
142 match first_term(&values) {
143 Term::ZeroTerm => break,
144 Term::Term(coeffx, degree) => {
145 coeff = coeffx;
146 self_degree = degree;
147 }
148 }
149 }
150 div
151}
152
153fn normalize<N: Zero + Copy + DivAssign>(values: &mut [N]) {
154 let f_i = first_nonzero_index(values);
155 if f_i == values.len() {
156 return;
157 }
158 let first = values[f_i];
159 for val in values[f_i..].iter_mut() {
160 *val /= first;
161 }
162}
163
164pub(crate) fn find_roots_special(poly: &[(f64, usize)]) -> Option<Roots<f64>> {
166 Some(match poly {
167 [] => Roots::InfiniteRoots,
168 [(_, 0)] => Roots::NoRoots,
169 [_] => Roots::ManyRealRoots(vec![0.]),
170 [(c1, 1), (c2, 0)] => Roots::ManyRealRoots(vec![-*c2 / *c1]),
171 [(a, 2), one_or_more @ ..] => {
172 let (b, c) = match one_or_more {
173 [] => (0., 0.),
174 [(xc, 0)] => (0., *xc),
175 [(xb, 1)] => (*xb, 0.),
176 [(xb, 1), (xc, 0)] => (*xb, *xc),
177 _ => unreachable!(),
178 };
179 match trinomial_roots(*a, b, c) {
180 Roots::TwoComplexRoots(a, b) => Roots::ManyComplexRoots(vec![a, b]),
181 Roots::TwoRealRoots(a, b) => Roots::ManyRealRoots(vec![a, b]),
182 _ => unreachable!(),
183 }
184 }
185 [(a, 3), one_or_more @ ..] => {
186 let (b, c, d) = match one_or_more {
187 [] => (0., 0., 0.),
188 [(xd, 0)] => (0., 0., *xd),
189 [(xc, 1)] => (0., *xc, 0.),
190 [(xc, 1), (xd, 0)] => (0., *xc, *xd),
191 [(xb, 2)] => (*xb, 0., 0.),
192 [(xb, 2), (xd, 0)] => (*xb, 0., *xd),
193 [(xb, 2), (xc, 1)] => (*xb, *xc, 0.),
194 [(xb, 2), (xc, 1), (xd, 0)] => (*xb, *xc, *xd),
195 _ => unreachable!(),
196 };
197 match cubic_roots(*a, b, c, d) {
198 Roots::ThreeComplexRoots(a, b, c) => Roots::ManyComplexRoots(vec![a, b, c]),
199 Roots::ThreeRealRoots(a, b, c) => Roots::ManyRealRoots(vec![a, b, c]),
200 _ => unreachable!(),
201 }
202 }
203 _ => return None,
204 })
205}
206
207pub(crate) fn find_roots<S: SizedPolynomial<f64> + Evaluable<f64>>(poly: &S) -> Roots<f64> {
210 let vals = poly.terms_as_vec();
211
212 if let Some(roots) = find_roots_special(&vals) {
213 return roots;
214 }
215
216 let (leading, degree) = vals[0];
222 let mut values = vec![0f64; degree + 1];
223 let mut nvalues = vec![0f64; degree + 1];
224
225 nvalues[0] = leading;
226 for (val, val_deg) in vals[1..].iter() {
227 values[degree - val_deg] = *val / leading;
228 nvalues[degree - val_deg] = *val;
229 }
230
231 let mut roots = vec![];
232 loop {
233 let temp_roots: Vec<f64> = find_roots_sturm(&values[1..], &mut 1e-8f64)
234 .into_iter()
235 .filter_map(Result::ok)
236 .collect();
237
238 if temp_roots.is_empty() {
239 match poly.degree() {
240 Degree::Num(x) => {
241 if x == temp_roots.len() {
242 return Roots::ManyRealRoots(roots);
243 }
244 }
245 _ => unreachable!("Polynomial should not be zero in this stage."),
246 }
247 return if roots.is_empty() {
248 Roots::NoRoots
249 } else {
250 Roots::OnlyRealRoots(roots)
251 };
252 }
253
254 for root in temp_roots {
255 let root = {
256 let x = root.round();
257 if poly.eval(x).abs() < poly.eval(root).abs() {
258 x
259 } else {
260 root
261 }
262 };
263 roots.push(root);
264 nvalues = div(&mut nvalues, root);
265 }
266
267 if nvalues.is_empty() {
268 return if roots.is_empty() {
269 Roots::NoRoots
270 } else {
271 Roots::ManyRealRoots(roots)
272 };
273 }
274 normalize(&mut nvalues);
275 let leading = nvalues[0];
276 values = nvalues
277 .iter()
278 .map(|&val| val / leading)
279 .collect::<Vec<f64>>();
280 }
281}
282
#[cfg(test)]
mod test {
    //! Regression tests for the closed-form and iterative root finders.

    use crate::polynomial::find_roots::{cubic_roots, find_roots};
    use crate::{LinearBinomial, Monomial, Polynomial, Roots, SizedPolynomial};

    /// The zero polynomial vanishes everywhere.
    #[test]
    fn test_roots_empty() {
        let zero_poly = Polynomial::<f64>::zero();
        let expected = Roots::InfiniteRoots;
        assert_eq!(expected, find_roots(&zero_poly));
    }

    /// A non-zero constant never vanishes.
    #[test]
    fn test_roots_constant() {
        let one = Monomial::new(1., 0);
        let expected = Roots::NoRoots;
        assert_eq!(expected, find_roots(&one));
    }

    /// The linear binomial built from `[1., 2.]` has the single root `-2`.
    #[test]
    fn test_roots_binomial() {
        let linear = LinearBinomial::new([1., 2.]);
        let expected = Roots::ManyRealRoots(vec![-2.]);
        assert_eq!(expected, find_roots(&linear));
    }

    /// `x³ + 6x² + 12x + 8 = (x + 2)³`: a triple root at `-2`.
    #[test]
    fn test_roots_cubic_a_equals_one() {
        let expected = Roots::ThreeRealRoots(-2.0, -2.0, -2.0);
        assert_eq!(expected, cubic_roots(1f64, 6., 12., 8.));
    }

    /// Scaling every coefficient by two must not move the roots.
    #[test]
    fn test_roots_cubic_a_does_not_equal_one() {
        let expected = Roots::ThreeRealRoots(-2.0, -2.0, -2.0);
        assert_eq!(expected, cubic_roots(2f64, 12., 24., 16.));
    }

    /// The same cubic routed through `find_roots`, which reports a list.
    #[test]
    fn test_cubic_polynomials() {
        let cube = Polynomial::new(vec![1f64, 6., 12., 8.]);
        let expected = Roots::ManyRealRoots(vec![-2., -2., -2.]);
        assert_eq!(expected, find_roots(&cube));
    }

    /// Degree ten is beyond the closed forms, so this exercises the Sturm
    /// search plus repeated deflation for the nine-fold root.
    #[test]
    fn test_large_polynomials() {
        let poly = Polynomial::new(vec![1f64, 2.]).pow(9) * Polynomial::new(vec![1f64, 3.]);
        let expected =
            Roots::ManyRealRoots(vec![-3., -2., -2., -2., -2., -2., -2., -2., -2., -2.]);
        assert_eq!(expected, find_roots(&poly));
    }

    /// `x⁴ + x³ + x² + x + 1` has no real roots.
    #[test]
    fn test_quad_no_real_roots() {
        let quartic = Polynomial::<f64>::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let expected = Roots::NoRoots;
        assert_eq!(expected, find_roots(&quartic));
    }
}