rust_poly/poly/
roots.rs

1use crate::{
2    util::{
3        complex::{c_from_f128, c_neg, c_sqrt, c_to_f128},
4        vec::slice_mean,
5    },
6    Poly, RealScalar,
7};
8use anyhow::anyhow;
9use f128::f128;
10use itertools::Itertools;
11use num::{Complex, FromPrimitive, One, Zero};
12
13mod single_root;
14pub use single_root::{halley, naive, newton};
15mod all_roots;
16pub use all_roots::{aberth_ehrlich, deflate, halley_deflate, naive_deflate, newton_deflate};
17mod many_roots;
18pub use many_roots::{halley_parallel, naive_parallel, newton_parallel, parallel};
19mod initial_guess;
20pub use initial_guess::{initial_guess_smallest, initial_guesses_circle};
21
/// Errors that a root finder in this module can report.
///
/// `T` is the type of the payload carried by [`Error::NoConverge`]; the
/// `Result` alias declared in this module instantiates it with
/// `Vec<Complex<T>>`.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error<T> {
    /// The solver stopped before meeting its convergence constraints.
    ///
    /// NOTE(review): the payload is presumably the best estimates found so
    /// far -- confirm against the individual solvers.
    #[error("root finder did not converge within the given constraints")]
    NoConverge(T),

    /// Any other failure, wrapped from an [`anyhow::Error`] (a `From`
    /// conversion is generated by `#[from]`).
    #[error("unexpected error while running root finder")]
    Other(#[from] anyhow::Error),
}
31
// TODO: make a type that contains results with some extra info and an `.unpack_roots` method.

/// Result type returned by the root finders in this module.
///
/// The success value is a vector of complex roots; the error is an [`Error`]
/// whose `NoConverge` payload is also a vector of complex numbers.
pub type Result<T> = std::result::Result<Vec<Complex<T>>, Error<Vec<Complex<T>>>>;

// TODO: use everywhere
/// A list of complex roots with real scalar type `T`.
pub type Roots<T> = Vec<Complex<T>>;
38
/// Selects how [`Poly::roots_expert`] refines the roots after the main solver
/// has run.
#[derive(Debug, Clone)]
pub enum PolishingMode<T> {
    /// Skip polishing; the roots from the main solver are used as they are.
    None,
    /// Polish with `newton_parallel` in the polynomial's own scalar type.
    StandardPrecision {
        /// Tolerance handed to the polishing solver.
        epsilon: T,
        /// Accepted for forward compatibility; `roots_expert` does not
        /// currently read it.
        min_iter: usize,
        /// Iteration cap handed to the polishing solver.
        max_iter: usize,
    },
    /// Polish with `newton_parallel` after casting the polynomial and the
    /// roots to `f128`, then cast the results back to `T`.
    ///
    /// Only available when compiling for `x86_64`.
    #[cfg(target_arch = "x86_64")]
    HighPrecision {
        /// Tolerance handed to the polishing solver (converted through `f64`).
        epsilon: T,
        /// Accepted for forward compatibility; `roots_expert` does not
        /// currently read it.
        min_iter: usize,
        /// Iteration cap handed to the polishing solver.
        max_iter: usize,
    },
}
53
/// Selects how [`Poly::roots_expert`] post-processes roots that lie close to
/// each other (candidate multiple roots).
///
/// In every variant other than `None`, roots are first grouped by
/// `group_multiples`, which compares the *squared* distance between a root and
/// a group's mean against `detection_epsilon`.
#[derive(Debug, Clone)]
pub enum MultiplesHandlingMode<T> {
    /// Return the roots untouched.
    None,
    /// Replace every root in a group with the group's best root (the one with
    /// the smallest squared polynomial value), keeping the group's size.
    BroadcastBest { detection_epsilon: T },
    /// Replace every root in a group with the group's average, keeping the
    /// group's size.
    BroadcastAverage { detection_epsilon: T },
    /// Keep a single root per group: the group's best root.
    KeepBest { detection_epsilon: T },
    /// Keep a single root per group: the group's average.
    KeepAverage { detection_epsilon: T },
}
61
/// Selects how [`Poly::roots_expert`] obtains the initial guesses that are not
/// covered by the caller-supplied guess pool.
#[derive(Debug, Clone)]
pub enum InitialGuessMode<T> {
    /// Use only the caller-supplied pool; `roots_expert` returns an error if
    /// the pool holds fewer guesses than the polynomial has roots left to find.
    GuessPoolOnly,
    /// Fill the missing guesses with `initial_guesses_circle`, forwarding
    /// `bias`, `seed` and `perturbation` to it.
    RandomAnnulus { bias: T, perturbation: T, seed: u64 },
    // TODO: Hull {},
    // TODO: GridSearch {},
}
68
69impl<T: RealScalar> Poly<T> {
70    /// A convenient way of finding roots, with a pre-configured root finder.
71    /// Should work well for most real polynomials of low degree.
72    ///
73    /// Use [`Poly::roots_expert`] if you need more control over performance or accuracy.
74    ///
75    /// # Errors
76    /// - Solver did not converge within `max_iter` iterations
77    /// - Some other edge-case was encountered which could not be handled (please
78    ///   report this, as we can make this solver more robust!)
79    pub fn roots(&self, epsilon: T, max_iter: usize) -> Result<T> {
80        self.roots_expert(
81            epsilon.clone(),
82            max_iter,
83            0,
84            PolishingMode::StandardPrecision {
85                epsilon: epsilon.clone(),
86                min_iter: 0,
87                max_iter,
88            },
89            MultiplesHandlingMode::BroadcastBest {
90                // TODO: tune ratio
91                detection_epsilon: epsilon * T::from_f64(1.5).expect("overflow"),
92            },
93            &[],
94            InitialGuessMode::RandomAnnulus {
95                bias: T::from_f64(0.5).expect("overflow"),
96                perturbation: T::from_f64(0.5).expect("overflow"),
97                seed: 1,
98            },
99        )
100    }
101
102    /// Highly configurable root finder.
103    ///
104    /// [`Poly::roots`] will often be good enough, but you may know something
105    /// about the polynomial you are factoring that allows you to tweak the
106    /// settings.
107    ///
108    /// # Errors
109    /// - Solver did not converge within `max_iter` iterations
110    /// - Some other edge-case was encountered which could not be handled (please
111    ///   report this, as we can make this solver more robust!)
112    /// - The combination of parameters that was provided is invalid
113    pub fn roots_expert(
114        &self,
115        epsilon: T,
116        max_iter: usize,
117        _min_iter: usize,
118        polishing_mode: PolishingMode<T>,
119        multiples_handling_mode: MultiplesHandlingMode<T>,
120        initial_guess_pool: &[Complex<T>],
121        initial_guess_mode: InitialGuessMode<T>,
122    ) -> Result<T> {
123        debug_assert!(self.is_normalized());
124
125        let mut this = self.clone();
126
127        let mut roots: Vec<Complex<T>> = this.zero_roots(epsilon.clone());
128
129        match this.degree_raw() {
130            1 => {
131                roots.extend(this.linear_roots());
132                return Ok(roots);
133            }
134            2 => {
135                roots.extend(this.quadratic_roots());
136                return Ok(roots);
137            }
138            _ => {}
139        }
140
141        this.make_monic();
142
143        debug_assert!(this.is_normalized());
144        let mut initial_guesses = Vec::with_capacity(this.degree_raw());
145        for guess in initial_guess_pool.iter().cloned() {
146            initial_guesses.push(guess);
147        }
148
149        // fill remaining guesses with zeros and prepare to replace with computed
150        // initial guesses
151        let delta = this.degree_raw() - initial_guesses.len();
152        for _ in 0..delta {
153            initial_guesses.push(Complex::<T>::zero());
154        }
155        let mut remaining_guesses_view =
156            &mut initial_guesses[initial_guess_pool.len()..this.degree_raw()];
157
158        match initial_guess_mode {
159            InitialGuessMode::GuessPoolOnly => {
160                if initial_guess_pool.len() < this.degree_raw() {
161                    return Err(Error::Other(anyhow!("not enough initial guesses, you must provide one guess per root when using GuessPoolOnly")));
162                }
163            }
164            InitialGuessMode::RandomAnnulus {
165                bias,
166                perturbation,
167                seed,
168            } => {
169                initial_guesses_circle(
170                    &this,
171                    bias,
172                    seed,
173                    perturbation,
174                    &mut remaining_guesses_view,
175                );
176            } // TODO: InitialGuessMode::Hull {} => todo!(),
177              // TODO: InitialGuessMode::GridSearch {} => todo!(),
178        }
179
180        log::trace!("{initial_guesses:?}");
181
182        roots.extend(aberth_ehrlich(
183            &mut this,
184            Some(epsilon.clone()),
185            Some(max_iter),
186            &initial_guesses,
187        )?);
188
189        // further polishing of roots
190        let roots: Roots<T> = match polishing_mode {
191            PolishingMode::None => Ok(roots),
192            PolishingMode::StandardPrecision {
193                epsilon,
194                min_iter,
195                max_iter,
196            } => newton_parallel(&mut this, Some(epsilon), Some(max_iter), &roots),
197
198            #[cfg(target_arch = "x86_64")]
199            PolishingMode::HighPrecision {
200                epsilon,
201                min_iter,
202                max_iter,
203            } => {
204                let mut this = this.clone().cast_to_f128();
205                let roots = roots.iter().cloned().map(|z| c_to_f128(z)).collect_vec();
206                newton_parallel(
207                    &mut this,
208                    Some(f128::from(epsilon.to_f64().expect("overflow"))),
209                    Some(max_iter),
210                    &roots,
211                )
212                .map(|v| v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
213                .map_err(|e| match e {
214                    Error::NoConverge(v) => {
215                        Error::NoConverge(v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
216                    }
217                    Error::Other(o) => Error::Other(o),
218                })
219            }
220        }?;
221
222        match multiples_handling_mode {
223            MultiplesHandlingMode::None => Ok(roots),
224            MultiplesHandlingMode::BroadcastBest { detection_epsilon } => Ok(best_multiples(
225                &this,
226                group_multiples(roots, detection_epsilon),
227                true,
228            )),
229            MultiplesHandlingMode::BroadcastAverage { detection_epsilon } => Ok(average_multiples(
230                &this,
231                group_multiples(roots, detection_epsilon),
232                true,
233            )),
234            MultiplesHandlingMode::KeepBest { detection_epsilon } => Ok(best_multiples(
235                &this,
236                group_multiples(roots, detection_epsilon),
237                false,
238            )),
239            MultiplesHandlingMode::KeepAverage { detection_epsilon } => Ok(average_multiples(
240                &this,
241                group_multiples(roots, detection_epsilon),
242                false,
243            )),
244        }
245    }
246}
247
248// private
249impl<T: RealScalar> Poly<T> {
250    fn zero_roots(&mut self, epsilon: T) -> Vec<Complex<T>> {
251        debug_assert!(self.is_normalized());
252
253        let mut roots = vec![];
254        for _ in 0..self.degree_raw() {
255            if self.eval(Complex::zero()).norm_sqr() < epsilon {
256                roots.push(Complex::zero());
257                // deflating zero roots can be accomplished simply by shifting
258                *self = self.shift_down(1);
259            } else {
260                break;
261            }
262        }
263
264        roots
265    }
266
267    fn linear_roots(&mut self) -> Vec<Complex<T>> {
268        debug_assert!(self.is_normalized());
269        debug_assert_eq!(self.degree_raw(), 1);
270
271        self.trim();
272        if self.degree_raw() < 1 {
273            return vec![];
274        }
275
276        let a = self.0[1].clone();
277        let b = self.0[0].clone();
278
279        // we found all the roots
280        *self = Self::one();
281
282        vec![-b / a]
283    }
284
285    /// Quadratic formula
286    fn quadratic_roots(&mut self) -> Vec<Complex<T>> {
287        debug_assert!(self.is_normalized());
288        debug_assert_eq!(self.degree_raw(), 2);
289
290        // trimming trailing almost zeros to avoid overflow
291        self.trim();
292        if self.degree_raw() == 1 {
293            return self.linear_roots();
294        }
295        if self.degree_raw() == 0 {
296            return vec![];
297        }
298
299        let a = self.0[2].clone();
300        let b = self.0[1].clone();
301        let c = self.0[0].clone();
302        let four = Complex::<T>::from_u8(4).expect("overflow");
303        let two = Complex::<T>::from_u8(2).expect("overflow");
304
305        // TODO: switch to different formula when b^2 and 4c are very close due
306        //       to loss of precision
307        let plus_minus_term = c_sqrt(b.clone() * b.clone() - four * a.clone() * c);
308        let x1 = (plus_minus_term.clone() - b.clone()) / (two.clone() * a.clone());
309        let x2 = (c_neg(b.clone()) - plus_minus_term) / (two * a);
310
311        // we found all the roots
312        *self = Self::one();
313
314        vec![x1, x2]
315    }
316}
317
318/// Find roots that are within a given tolerance from each other and group them
319fn group_multiples<T: RealScalar>(roots: Roots<T>, epsilon: T) -> Vec<Roots<T>> {
320    // groups with their respective mean
321    let mut groups: Vec<(Roots<T>, Complex<T>)> = vec![];
322
323    let mut roots = roots;
324
325    while roots.len() > 0 {
326        // now for each root we find a group whose median is within tolerance,
327        // if we don't find any we add a new group with the one root
328        // if we do, we add the point, update the mean
329        'roots_loop: for root in roots.drain(..) {
330            for group in &mut groups {
331                if (group.1.clone() - root.clone()).norm_sqr() <= epsilon {
332                    group.0.push(root.clone());
333                    group.1 = slice_mean(&group.0);
334                    continue 'roots_loop;
335                }
336            }
337            groups.push((vec![root.clone()], root));
338        }
339
340        // now we loop through all the groups and through each element in each group
341        // and remove any elements that now lay outside of the tolerance
342        for group in &mut groups {
343            // hijacking retain to avoid having to write a loop where we delete
344            // things from the collection we're iterating from.
345            group.0.retain(|r| {
346                if (r.clone() - group.1.clone()).norm_sqr() <= epsilon {
347                    true
348                } else {
349                    roots.push(r.clone());
350                    false
351                }
352            })
353        }
354
355        // finally we prune empty groups
356        groups.retain(|g| g.0.len() > 0);
357    }
358
359    groups.into_iter().map(|(r, _)| r).collect_vec()
360}
361
362fn best_multiples<T: RealScalar>(
363    poly: &Poly<T>,
364    groups: Vec<Roots<T>>,
365    do_broadcast: bool,
366) -> Roots<T> {
367    // find the best root in each group
368    groups
369        .into_iter()
370        .map(|group| {
371            let len = group.len();
372            let best = group
373                .into_iter()
374                .map(|root| (root.clone(), poly.eval(root).norm_sqr()))
375                .reduce(|(a_root, a_eval), (b_root, b_eval)| {
376                    if a_eval < b_eval {
377                        (a_root, a_eval)
378                    } else {
379                        (b_root, b_eval)
380                    }
381                })
382                .expect("empty groups not allowed")
383                .0;
384            if do_broadcast {
385                vec![best; len]
386            } else {
387                vec![best]
388            }
389        })
390        .flatten()
391        .collect_vec()
392}
393
394fn average_multiples<T: RealScalar>(
395    poly: &Poly<T>,
396    groups: Vec<Roots<T>>,
397    do_broadcast: bool,
398) -> Roots<T> {
399    groups
400        .into_iter()
401        .map(|group| {
402            let len_usize = group.len();
403            debug_assert!(len_usize > 0);
404            let len = T::from_usize(len_usize).expect("infallible");
405            let sum: Complex<T> = group.into_iter().sum();
406            let avg = sum / len;
407            if do_broadcast {
408                vec![avg; len_usize]
409            } else {
410                vec![avg]
411            }
412        })
413        .flatten()
414        .collect_vec()
415}
416
#[cfg(test)]
mod test {
    use num::complex::ComplexFloat;

    use crate::Poly64;

    /// Regression test for
    /// [#3](https://github.com/PanieriLorenzo/rust-poly/issues/3): both roots
    /// of the order-2 reverse Bessel polynomial must be close to
    /// `-1.5 +/- 0.866i`.
    #[test]
    fn roots_of_reverse_bessel() {
        let bessel = Poly64::reverse_bessel(2).unwrap();
        let found = bessel.roots(1E-10, 1000).unwrap();
        for root in &found[..2] {
            assert!((root.re() + 1.5).abs() < 0.01);
            assert!((root.im().abs() - 0.866).abs() < 0.01);
        }
    }
}