rust_poly/poly/
roots.rs

1use crate::{
2    util::{
3        complex::{c_from_f128, c_neg, c_sqrt, c_to_f128},
4        vec::slice_mean,
5    },
6    Poly, RealScalar,
7};
8use anyhow::anyhow;
9use f128::f128;
10use itertools::Itertools;
11use num::{Complex, FromPrimitive, One, Zero};
12
13mod single_root;
14pub use single_root::{halley, naive, newton};
15mod all_roots;
16pub use all_roots::{aberth_ehrlich, deflate, halley_deflate, naive_deflate, newton_deflate};
17mod many_roots;
18pub use many_roots::{halley_parallel, naive_parallel, newton_parallel, parallel};
19mod initial_guess;
20pub use initial_guess::{initial_guess_smallest, initial_guesses_circle};
21
22#[derive(thiserror::Error, Debug)]
23#[non_exhaustive]
24pub enum Error<T> {
25    #[error("root finder did not converge within the given constraints")]
26    NoConverge(T),
27
28    #[error("unexpected error while running root finder")]
29    Other(#[from] anyhow::Error),
30}
31
32// TODO: make a type that contains results with some extra info and an `.unpack_roots` method.
33
34pub type Result<T> = std::result::Result<Vec<Complex<T>>, Error<Vec<Complex<T>>>>;
35
36// TODO: use everywhere
37pub type Roots<T> = Vec<Complex<T>>;
38
/// Selects the refinement pass that [`Poly::roots_expert`] runs on the roots
/// produced by the main solver.
#[derive(Debug, Clone, PartialEq)]
pub enum PolishingMode<T> {
    /// Skip polishing: the roots of the main solver are used as they are.
    None,
    /// Polish with Newton iterations in the working precision `T`.
    StandardPrecision {
        /// Tolerance handed to the Newton polisher.
        epsilon: T,
        /// Currently not forwarded to the polisher by [`Poly::roots_expert`].
        min_iter: usize,
        /// Iteration cap handed to the Newton polisher.
        max_iter: usize,
    },
    /// Polish with Newton iterations after casting the polynomial and the
    /// roots to `f128`; the results are cast back to `T` afterwards.
    ///
    /// Only available when compiling for `x86_64`.
    #[cfg(target_arch = "x86_64")]
    HighPrecision {
        /// Tolerance handed to the Newton polisher (converted through `f64`).
        epsilon: T,
        /// Currently not forwarded to the polisher by [`Poly::roots_expert`].
        min_iter: usize,
        /// Iteration cap handed to the Newton polisher.
        max_iter: usize,
    },
}
53
/// Selects how [`Poly::roots_expert`] post-processes roots that lie close to
/// each other (candidate multiple roots).
///
/// In every variant `detection_epsilon` is the grouping tolerance. Note that
/// the grouping code compares it against the *squared* distance between a root
/// and the mean of its group.
#[derive(Debug, Clone, PartialEq)]
pub enum MultiplesHandlingMode<T> {
    /// Return the roots untouched.
    None,
    /// Replace every root of a group with the group member at which the
    /// polynomial evaluates to the smallest squared magnitude. The number of
    /// roots is preserved.
    BroadcastBest { detection_epsilon: T },
    /// Replace every root of a group with the arithmetic mean of the group.
    /// The number of roots is preserved.
    BroadcastAverage { detection_epsilon: T },
    /// Keep a single root per group: the member at which the polynomial
    /// evaluates to the smallest squared magnitude.
    KeepBest { detection_epsilon: T },
    /// Keep a single root per group: the arithmetic mean of the group.
    KeepAverage { detection_epsilon: T },
}
61
/// Selects how [`Poly::roots_expert`] obtains the initial guesses that are not
/// covered by the caller-provided guess pool.
#[derive(Debug, Clone, PartialEq)]
pub enum InitialGuessMode<T> {
    /// Use only the guesses supplied by the caller. The solver reports an
    /// error when the pool holds fewer guesses than there are roots left to
    /// find.
    GuessPoolOnly,
    /// Generate the missing guesses with `initial_guesses_circle`.
    ///
    /// `bias`, `perturbation` and `seed` are forwarded to that function
    /// unchanged; see its documentation for their exact meaning.
    RandomAnnulus { bias: T, perturbation: T, seed: u64 },
    // TODO: Hull {},
    // TODO: GridSearch {},
}
68
69impl<T: RealScalar> Poly<T> {
70    /// A convenient way of finding roots, with a pre-configured root finder.
71    /// Should work well for most real polynomials of low degree.
72    ///
73    /// Use [`Poly::roots_expert`] if you need more control over performance or accuracy.
74    ///
75    /// # Errors
76    /// - Solver did not converge within `max_iter` iterations
77    /// - Some other edge-case was encountered which could not be handled (please
78    ///   report this, as we can make this solver more robust!)
79    pub fn roots(&self, epsilon: T, max_iter: usize) -> Result<T> {
80        self.roots_expert(
81            epsilon.clone(),
82            max_iter,
83            0,
84            PolishingMode::StandardPrecision {
85                epsilon: epsilon.clone(),
86                min_iter: 0,
87                max_iter,
88            },
89            MultiplesHandlingMode::BroadcastBest {
90                // TODO: tune ratio
91                detection_epsilon: epsilon * T::from_f64(1.5).expect("overflow"),
92            },
93            &[],
94            InitialGuessMode::RandomAnnulus {
95                bias: T::from_f64(0.5).expect("overflow"),
96                perturbation: T::from_f64(0.5).expect("overflow"),
97                seed: 1,
98            },
99        )
100    }
101
102    /// Highly configurable root finder.
103    ///
104    /// [`Poly::roots`] will often be good enough, but you may know something
105    /// about the polynomial you are factoring that allows you to tweak the
106    /// settings.
107    ///
108    /// # Errors
109    /// - Solver did not converge within `max_iter` iterations
110    /// - Some other edge-case was encountered which could not be handled (please
111    ///   report this, as we can make this solver more robust!)
112    /// - The combination of parameters that was provided is invalid
113    pub fn roots_expert(
114        &self,
115        epsilon: T,
116        max_iter: usize,
117        _min_iter: usize,
118        polishing_mode: PolishingMode<T>,
119        multiples_handling_mode: MultiplesHandlingMode<T>,
120        initial_guess_pool: &[Complex<T>],
121        initial_guess_mode: InitialGuessMode<T>,
122    ) -> Result<T> {
123        debug_assert!(self.is_normalized());
124
125        let mut this = self.clone();
126
127        let mut roots: Vec<Complex<T>> = this.zero_roots(epsilon.clone());
128
129        match this.degree_raw() {
130            1 => {
131                roots.extend(this.linear_roots());
132                return Ok(roots);
133            }
134            2 => {
135                roots.extend(this.quadratic_roots());
136                return Ok(roots);
137            }
138            _ => {}
139        }
140
141        this.make_monic();
142
143        debug_assert!(this.is_normalized());
144        let mut initial_guesses = Vec::with_capacity(this.degree_raw());
145        for guess in initial_guess_pool.iter().cloned() {
146            initial_guesses.push(guess);
147        }
148
149        // fill remaining guesses with zeros and prepare to replace with computed
150        // initial guesses
151        let delta = this.degree_raw() - initial_guesses.len();
152        for _ in 0..delta {
153            initial_guesses.push(Complex::<T>::zero());
154        }
155        let remaining_guesses_view =
156            &mut initial_guesses[initial_guess_pool.len()..this.degree_raw()];
157
158        match initial_guess_mode {
159            InitialGuessMode::GuessPoolOnly => {
160                if initial_guess_pool.len() < this.degree_raw() {
161                    return Err(Error::Other(anyhow!("not enough initial guesses, you must provide one guess per root when using GuessPoolOnly")));
162                }
163            }
164            InitialGuessMode::RandomAnnulus {
165                bias,
166                perturbation,
167                seed,
168            } => {
169                initial_guesses_circle(&this, bias, seed, perturbation, remaining_guesses_view);
170            } // TODO: InitialGuessMode::Hull {} => todo!(),
171              // TODO: InitialGuessMode::GridSearch {} => todo!(),
172        }
173
174        log::trace!("{initial_guesses:?}");
175
176        roots.extend(aberth_ehrlich(
177            &mut this,
178            Some(epsilon.clone()),
179            Some(max_iter),
180            &initial_guesses,
181        )?);
182
183        // further polishing of roots
184        let roots: Roots<T> = match polishing_mode {
185            PolishingMode::None => Ok(roots),
186            PolishingMode::StandardPrecision {
187                epsilon,
188                min_iter,
189                max_iter,
190            } => newton_parallel(&mut this, Some(epsilon), Some(max_iter), &roots),
191
192            #[cfg(target_arch = "x86_64")]
193            PolishingMode::HighPrecision {
194                epsilon,
195                min_iter,
196                max_iter,
197            } => {
198                let mut this = this.clone().cast_to_f128();
199                let roots = roots.iter().cloned().map(|z| c_to_f128(z)).collect_vec();
200                newton_parallel(
201                    &mut this,
202                    Some(f128::from(epsilon.to_f64().expect("overflow"))),
203                    Some(max_iter),
204                    &roots,
205                )
206                .map(|v| v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
207                .map_err(|e| match e {
208                    Error::NoConverge(v) => {
209                        Error::NoConverge(v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
210                    }
211                    Error::Other(o) => Error::Other(o),
212                })
213            }
214        }?;
215
216        match multiples_handling_mode {
217            MultiplesHandlingMode::None => Ok(roots),
218            MultiplesHandlingMode::BroadcastBest { detection_epsilon } => Ok(best_multiples(
219                &this,
220                group_multiples(roots, detection_epsilon),
221                true,
222            )),
223            MultiplesHandlingMode::BroadcastAverage { detection_epsilon } => Ok(average_multiples(
224                &this,
225                group_multiples(roots, detection_epsilon),
226                true,
227            )),
228            MultiplesHandlingMode::KeepBest { detection_epsilon } => Ok(best_multiples(
229                &this,
230                group_multiples(roots, detection_epsilon),
231                false,
232            )),
233            MultiplesHandlingMode::KeepAverage { detection_epsilon } => Ok(average_multiples(
234                &this,
235                group_multiples(roots, detection_epsilon),
236                false,
237            )),
238        }
239    }
240}
241
242// private
243impl<T: RealScalar> Poly<T> {
244    fn zero_roots(&mut self, epsilon: T) -> Vec<Complex<T>> {
245        debug_assert!(self.is_normalized());
246
247        let mut roots = vec![];
248        for _ in 0..self.degree_raw() {
249            if self.eval(Complex::zero()).norm_sqr() < epsilon {
250                roots.push(Complex::zero());
251                // deflating zero roots can be accomplished simply by shifting
252                *self = self.shift_down(1);
253            } else {
254                break;
255            }
256        }
257
258        roots
259    }
260
261    fn linear_roots(&mut self) -> Vec<Complex<T>> {
262        debug_assert!(self.is_normalized());
263        debug_assert_eq!(self.degree_raw(), 1);
264
265        self.trim();
266        if self.degree_raw() < 1 {
267            return vec![];
268        }
269
270        let a = self.0[1].clone();
271        let b = self.0[0].clone();
272
273        // we found all the roots
274        *self = Self::one();
275
276        vec![-b / a]
277    }
278
279    /// Quadratic formula
280    fn quadratic_roots(&mut self) -> Vec<Complex<T>> {
281        debug_assert!(self.is_normalized());
282        debug_assert_eq!(self.degree_raw(), 2);
283
284        // trimming trailing almost zeros to avoid overflow
285        self.trim();
286        if self.degree_raw() == 1 {
287            return self.linear_roots();
288        }
289        if self.degree_raw() == 0 {
290            return vec![];
291        }
292
293        let a = self.0[2].clone();
294        let b = self.0[1].clone();
295        let c = self.0[0].clone();
296        let four = Complex::<T>::from_u8(4).expect("overflow");
297        let two = Complex::<T>::from_u8(2).expect("overflow");
298
299        // TODO: switch to different formula when b^2 and 4c are very close due
300        //       to loss of precision
301        let plus_minus_term = c_sqrt(b.clone() * b.clone() - four * a.clone() * c);
302        let x1 = (plus_minus_term.clone() - b.clone()) / (two.clone() * a.clone());
303        let x2 = (c_neg(b.clone()) - plus_minus_term) / (two * a);
304
305        // we found all the roots
306        *self = Self::one();
307
308        vec![x1, x2]
309    }
310}
311
312/// Find roots that are within a given tolerance from each other and group them
313fn group_multiples<T: RealScalar>(roots: Roots<T>, epsilon: T) -> Vec<Roots<T>> {
314    // groups with their respective mean
315    let mut groups: Vec<(Roots<T>, Complex<T>)> = vec![];
316
317    let mut roots = roots;
318
319    while !roots.is_empty() {
320        // now for each root we find a group whose median is within tolerance,
321        // if we don't find any we add a new group with the one root
322        // if we do, we add the point, update the mean
323        'roots_loop: for root in roots.drain(..) {
324            for group in &mut groups {
325                if (group.1.clone() - root.clone()).norm_sqr() <= epsilon {
326                    group.0.push(root.clone());
327                    group.1 = slice_mean(&group.0);
328                    continue 'roots_loop;
329                }
330            }
331            groups.push((vec![root.clone()], root));
332        }
333
334        // now we loop through all the groups and through each element in each group
335        // and remove any elements that now lay outside of the tolerance
336        for group in &mut groups {
337            // hijacking retain to avoid having to write a loop where we delete
338            // things from the collection we're iterating from.
339            group.0.retain(|r| {
340                if (r.clone() - group.1.clone()).norm_sqr() <= epsilon {
341                    true
342                } else {
343                    roots.push(r.clone());
344                    false
345                }
346            });
347        }
348
349        // finally we prune empty groups
350        groups.retain(|g| !g.0.is_empty());
351    }
352
353    groups.into_iter().map(|(r, _)| r).collect_vec()
354}
355
356fn best_multiples<T: RealScalar>(
357    poly: &Poly<T>,
358    groups: Vec<Roots<T>>,
359    do_broadcast: bool,
360) -> Roots<T> {
361    // find the best root in each group
362    groups
363        .into_iter()
364        .flat_map(|group| {
365            let len = group.len();
366            let best = group
367                .into_iter()
368                .map(|root| (root.clone(), poly.eval(root).norm_sqr()))
369                .reduce(|(a_root, a_eval), (b_root, b_eval)| {
370                    if a_eval < b_eval {
371                        (a_root, a_eval)
372                    } else {
373                        (b_root, b_eval)
374                    }
375                })
376                .expect("empty groups not allowed")
377                .0;
378            if do_broadcast {
379                vec![best; len]
380            } else {
381                vec![best]
382            }
383        })
384        .collect_vec()
385}
386
387fn average_multiples<T: RealScalar>(
388    poly: &Poly<T>,
389    groups: Vec<Roots<T>>,
390    do_broadcast: bool,
391) -> Roots<T> {
392    groups
393        .into_iter()
394        .flat_map(|group| {
395            let len_usize = group.len();
396            debug_assert!(len_usize > 0);
397            let len = T::from_usize(len_usize).expect("infallible");
398            let sum: Complex<T> = group.into_iter().sum();
399            let avg = sum / len;
400            if do_broadcast {
401                vec![avg; len_usize]
402            } else {
403                vec![avg]
404            }
405        })
406        .collect_vec()
407}
408
#[cfg(test)]
mod test {
    use num::complex::ComplexFloat;

    use crate::Poly64;

    /// See [#3](https://github.com/PanieriLorenzo/rust-poly/issues/3)
    ///
    /// Both roots are expected near `-1.5 ± 0.866i`.
    #[test]
    fn roots_of_reverse_bessel() {
        let bessel = Poly64::reverse_bessel(2).unwrap();
        let found = bessel.roots(1E-10, 1000).unwrap();
        // slicing (rather than `take`) keeps the failure when fewer than two
        // roots come back
        for root in &found[..2] {
            assert!((root.re() - -1.5).abs() < 0.01);
            assert!((root.im().abs() - 0.866).abs() < 0.01);
        }
    }
}