1use crate::{
2 util::{
3 complex::{c_from_f128, c_neg, c_sqrt, c_to_f128},
4 vec::slice_mean,
5 },
6 Poly, RealScalar,
7};
8use anyhow::anyhow;
9use f128::f128;
10use itertools::Itertools;
11use num::{Complex, FromPrimitive, One, Zero};
12
13mod single_root;
14pub use single_root::{halley, naive, newton};
15mod all_roots;
16pub use all_roots::{aberth_ehrlich, deflate, halley_deflate, naive_deflate, newton_deflate};
17mod many_roots;
18pub use many_roots::{halley_parallel, naive_parallel, newton_parallel, parallel};
19mod initial_guess;
20pub use initial_guess::{initial_guess_smallest, initial_guesses_circle};
21
22#[derive(thiserror::Error, Debug)]
23#[non_exhaustive]
24pub enum Error<T> {
25 #[error("root finder did not converge within the given constraints")]
26 NoConverge(T),
27
28 #[error("unexpected error while running root finder")]
29 Other(#[from] anyhow::Error),
30}
31
32pub type Result<T> = std::result::Result<Vec<Complex<T>>, Error<Vec<Complex<T>>>>;
35
36pub type Roots<T> = Vec<Complex<T>>;
38
/// Optional Newton refinement that [`Poly::roots_expert`] applies to the roots
/// produced by Aberth–Ehrlich iteration.
///
/// Derives `Clone` so one configuration can be reused across calls (the solver
/// takes the mode by value) and `Debug` so it can be logged.
#[derive(Debug, Clone, PartialEq)]
pub enum PolishingMode<T> {
    /// Return the roots without refinement.
    None,
    /// Refine with Newton iteration in the scalar type `T`, using tolerance
    /// `epsilon` and at most `max_iter` iterations. `min_iter` is currently
    /// ignored by `roots_expert`.
    StandardPrecision {
        epsilon: T,
        min_iter: usize,
        max_iter: usize,
    },
    /// Like `StandardPrecision`, but the polynomial and roots are cast to `f128`
    /// for the refinement and cast back afterwards; `epsilon` is converted via
    /// `f64`. Only available on `x86_64`. `min_iter` is currently ignored.
    #[cfg(target_arch = "x86_64")]
    HighPrecision {
        epsilon: T,
        min_iter: usize,
        max_iter: usize,
    },
}
53
/// How [`Poly::roots_expert`] post-processes clusters of (nearly) coincident
/// roots.
///
/// A root belongs to a cluster when its *squared* distance (`norm_sqr`) to the
/// cluster mean is at most `detection_epsilon`, so this threshold is on the
/// scale of a squared distance, not a distance.
#[derive(Debug, Clone, PartialEq)]
pub enum MultiplesHandlingMode<T> {
    /// Return the roots unchanged.
    None,
    /// Replace every member of a cluster with the member minimising `|p(z)|²`,
    /// keeping one copy per member (the multiplicity is preserved).
    BroadcastBest { detection_epsilon: T },
    /// Replace every member of a cluster with the cluster mean, keeping one copy
    /// per member.
    BroadcastAverage { detection_epsilon: T },
    /// Collapse each cluster to a single copy of its member minimising `|p(z)|²`.
    KeepBest { detection_epsilon: T },
    /// Collapse each cluster to a single copy of its mean.
    KeepAverage { detection_epsilon: T },
}
61
/// How [`Poly::roots_expert`] fills the initial-guess slots that the
/// caller-supplied guess pool does not cover.
#[derive(Debug, Clone, PartialEq)]
pub enum InitialGuessMode<T> {
    /// Use only the caller's pool. `roots_expert` returns [`Error::Other`] when
    /// the pool has fewer guesses than the (deflated) polynomial has roots.
    GuessPoolOnly,
    /// Fill the remaining slots with `initial_guesses_circle`, parameterised by
    /// `bias`, `perturbation` and `seed` (NOTE(review): presumably a randomly
    /// perturbed circle/annulus around the origin — see the `initial_guess`
    /// module for the exact geometry).
    RandomAnnulus { bias: T, perturbation: T, seed: u64 },
}
68
69impl<T: RealScalar> Poly<T> {
70 pub fn roots(&self, epsilon: T, max_iter: usize) -> Result<T> {
80 self.roots_expert(
81 epsilon.clone(),
82 max_iter,
83 0,
84 PolishingMode::StandardPrecision {
85 epsilon: epsilon.clone(),
86 min_iter: 0,
87 max_iter,
88 },
89 MultiplesHandlingMode::BroadcastBest {
90 detection_epsilon: epsilon * T::from_f64(1.5).expect("overflow"),
92 },
93 &[],
94 InitialGuessMode::RandomAnnulus {
95 bias: T::from_f64(0.5).expect("overflow"),
96 perturbation: T::from_f64(0.5).expect("overflow"),
97 seed: 1,
98 },
99 )
100 }
101
102 pub fn roots_expert(
114 &self,
115 epsilon: T,
116 max_iter: usize,
117 _min_iter: usize,
118 polishing_mode: PolishingMode<T>,
119 multiples_handling_mode: MultiplesHandlingMode<T>,
120 initial_guess_pool: &[Complex<T>],
121 initial_guess_mode: InitialGuessMode<T>,
122 ) -> Result<T> {
123 debug_assert!(self.is_normalized());
124
125 let mut this = self.clone();
126
127 let mut roots: Vec<Complex<T>> = this.zero_roots(epsilon.clone());
128
129 match this.degree_raw() {
130 1 => {
131 roots.extend(this.linear_roots());
132 return Ok(roots);
133 }
134 2 => {
135 roots.extend(this.quadratic_roots());
136 return Ok(roots);
137 }
138 _ => {}
139 }
140
141 this.make_monic();
142
143 debug_assert!(this.is_normalized());
144 let mut initial_guesses = Vec::with_capacity(this.degree_raw());
145 for guess in initial_guess_pool.iter().cloned() {
146 initial_guesses.push(guess);
147 }
148
149 let delta = this.degree_raw() - initial_guesses.len();
152 for _ in 0..delta {
153 initial_guesses.push(Complex::<T>::zero());
154 }
155 let mut remaining_guesses_view =
156 &mut initial_guesses[initial_guess_pool.len()..this.degree_raw()];
157
158 match initial_guess_mode {
159 InitialGuessMode::GuessPoolOnly => {
160 if initial_guess_pool.len() < this.degree_raw() {
161 return Err(Error::Other(anyhow!("not enough initial guesses, you must provide one guess per root when using GuessPoolOnly")));
162 }
163 }
164 InitialGuessMode::RandomAnnulus {
165 bias,
166 perturbation,
167 seed,
168 } => {
169 initial_guesses_circle(
170 &this,
171 bias,
172 seed,
173 perturbation,
174 &mut remaining_guesses_view,
175 );
176 } }
179
180 log::trace!("{initial_guesses:?}");
181
182 roots.extend(aberth_ehrlich(
183 &mut this,
184 Some(epsilon.clone()),
185 Some(max_iter),
186 &initial_guesses,
187 )?);
188
189 let roots: Roots<T> = match polishing_mode {
191 PolishingMode::None => Ok(roots),
192 PolishingMode::StandardPrecision {
193 epsilon,
194 min_iter,
195 max_iter,
196 } => newton_parallel(&mut this, Some(epsilon), Some(max_iter), &roots),
197
198 #[cfg(target_arch = "x86_64")]
199 PolishingMode::HighPrecision {
200 epsilon,
201 min_iter,
202 max_iter,
203 } => {
204 let mut this = this.clone().cast_to_f128();
205 let roots = roots.iter().cloned().map(|z| c_to_f128(z)).collect_vec();
206 newton_parallel(
207 &mut this,
208 Some(f128::from(epsilon.to_f64().expect("overflow"))),
209 Some(max_iter),
210 &roots,
211 )
212 .map(|v| v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
213 .map_err(|e| match e {
214 Error::NoConverge(v) => {
215 Error::NoConverge(v.into_iter().map(|z| c_from_f128::<T>(z)).collect_vec())
216 }
217 Error::Other(o) => Error::Other(o),
218 })
219 }
220 }?;
221
222 match multiples_handling_mode {
223 MultiplesHandlingMode::None => Ok(roots),
224 MultiplesHandlingMode::BroadcastBest { detection_epsilon } => Ok(best_multiples(
225 &this,
226 group_multiples(roots, detection_epsilon),
227 true,
228 )),
229 MultiplesHandlingMode::BroadcastAverage { detection_epsilon } => Ok(average_multiples(
230 &this,
231 group_multiples(roots, detection_epsilon),
232 true,
233 )),
234 MultiplesHandlingMode::KeepBest { detection_epsilon } => Ok(best_multiples(
235 &this,
236 group_multiples(roots, detection_epsilon),
237 false,
238 )),
239 MultiplesHandlingMode::KeepAverage { detection_epsilon } => Ok(average_multiples(
240 &this,
241 group_multiples(roots, detection_epsilon),
242 false,
243 )),
244 }
245 }
246}
247
248impl<T: RealScalar> Poly<T> {
250 fn zero_roots(&mut self, epsilon: T) -> Vec<Complex<T>> {
251 debug_assert!(self.is_normalized());
252
253 let mut roots = vec![];
254 for _ in 0..self.degree_raw() {
255 if self.eval(Complex::zero()).norm_sqr() < epsilon {
256 roots.push(Complex::zero());
257 *self = self.shift_down(1);
259 } else {
260 break;
261 }
262 }
263
264 roots
265 }
266
267 fn linear_roots(&mut self) -> Vec<Complex<T>> {
268 debug_assert!(self.is_normalized());
269 debug_assert_eq!(self.degree_raw(), 1);
270
271 self.trim();
272 if self.degree_raw() < 1 {
273 return vec![];
274 }
275
276 let a = self.0[1].clone();
277 let b = self.0[0].clone();
278
279 *self = Self::one();
281
282 vec![-b / a]
283 }
284
285 fn quadratic_roots(&mut self) -> Vec<Complex<T>> {
287 debug_assert!(self.is_normalized());
288 debug_assert_eq!(self.degree_raw(), 2);
289
290 self.trim();
292 if self.degree_raw() == 1 {
293 return self.linear_roots();
294 }
295 if self.degree_raw() == 0 {
296 return vec![];
297 }
298
299 let a = self.0[2].clone();
300 let b = self.0[1].clone();
301 let c = self.0[0].clone();
302 let four = Complex::<T>::from_u8(4).expect("overflow");
303 let two = Complex::<T>::from_u8(2).expect("overflow");
304
305 let plus_minus_term = c_sqrt(b.clone() * b.clone() - four * a.clone() * c);
308 let x1 = (plus_minus_term.clone() - b.clone()) / (two.clone() * a.clone());
309 let x2 = (c_neg(b.clone()) - plus_minus_term) / (two * a);
310
311 *self = Self::one();
313
314 vec![x1, x2]
315 }
316}
317
318fn group_multiples<T: RealScalar>(roots: Roots<T>, epsilon: T) -> Vec<Roots<T>> {
320 let mut groups: Vec<(Roots<T>, Complex<T>)> = vec![];
322
323 let mut roots = roots;
324
325 while roots.len() > 0 {
326 'roots_loop: for root in roots.drain(..) {
330 for group in &mut groups {
331 if (group.1.clone() - root.clone()).norm_sqr() <= epsilon {
332 group.0.push(root.clone());
333 group.1 = slice_mean(&group.0);
334 continue 'roots_loop;
335 }
336 }
337 groups.push((vec![root.clone()], root));
338 }
339
340 for group in &mut groups {
343 group.0.retain(|r| {
346 if (r.clone() - group.1.clone()).norm_sqr() <= epsilon {
347 true
348 } else {
349 roots.push(r.clone());
350 false
351 }
352 })
353 }
354
355 groups.retain(|g| g.0.len() > 0);
357 }
358
359 groups.into_iter().map(|(r, _)| r).collect_vec()
360}
361
362fn best_multiples<T: RealScalar>(
363 poly: &Poly<T>,
364 groups: Vec<Roots<T>>,
365 do_broadcast: bool,
366) -> Roots<T> {
367 groups
369 .into_iter()
370 .map(|group| {
371 let len = group.len();
372 let best = group
373 .into_iter()
374 .map(|root| (root.clone(), poly.eval(root).norm_sqr()))
375 .reduce(|(a_root, a_eval), (b_root, b_eval)| {
376 if a_eval < b_eval {
377 (a_root, a_eval)
378 } else {
379 (b_root, b_eval)
380 }
381 })
382 .expect("empty groups not allowed")
383 .0;
384 if do_broadcast {
385 vec![best; len]
386 } else {
387 vec![best]
388 }
389 })
390 .flatten()
391 .collect_vec()
392}
393
394fn average_multiples<T: RealScalar>(
395 poly: &Poly<T>,
396 groups: Vec<Roots<T>>,
397 do_broadcast: bool,
398) -> Roots<T> {
399 groups
400 .into_iter()
401 .map(|group| {
402 let len_usize = group.len();
403 debug_assert!(len_usize > 0);
404 let len = T::from_usize(len_usize).expect("infallible");
405 let sum: Complex<T> = group.into_iter().sum();
406 let avg = sum / len;
407 if do_broadcast {
408 vec![avg; len_usize]
409 } else {
410 vec![avg]
411 }
412 })
413 .flatten()
414 .collect_vec()
415}
416
#[cfg(test)]
mod test {
    use num::complex::ComplexFloat;

    use crate::Poly64;

    /// The degree-2 reverse Bessel polynomial must yield the conjugate pair
    /// `-1.5 ± 0.866i` (in either order).
    #[test]
    fn roots_of_reverse_bessel() {
        let roots = Poly64::reverse_bessel(2)
            .unwrap()
            .roots(1E-10, 1000)
            .unwrap();

        for &root in &roots[..2] {
            assert!((root.re() - -1.5).abs() < 0.01);
            assert!((root.im().abs() - 0.866).abs() < 0.01);
        }
    }
}