sparse_ir/
gauss.rs

//! Gauss quadrature rules for numerical integration.
//!
//! This module provides quadrature rules that approximate integrals by weighted sums.
//!
//! The integral of `f(x) * omega(x)` over the rule's interval is approximated by the
//! weighted sum
//!
//! ```text
//! sum_i f(x_i) * w_i
//! ```
//!
//! over the quadrature points `x_i` and weights `w_i`. For smooth `f(x)` the error
//! generally decreases superexponentially with the number of quadrature points.
11
12use crate::numeric::CustomNumeric;
13use simba::scalar::ComplexField;
14use std::fmt::Debug;
15
/// Quadrature rule for numerical integration.
///
/// Represents an approximation of an integral over `[a, b]` by a weighted sum over
/// discrete points. Besides the points `x` and weights `w`, the rule stores the
/// distance of every point to both endpoints (`x_forward = x - a`,
/// `x_backward = b - x`). These distances are stored rather than recomputed from `x`.
/// [`Rule::reseat`] scales them directly, which avoids the cancellation that
/// evaluating `x - a` would incur for points very close to an endpoint.
///
/// The fields are public and are not checked on construction. [`Rule::validate`]
/// reports whether a rule is consistent:
/// - `a < b`;
/// - all arrays have equal length;
/// - the points lie in `[a, b]` and are strictly increasing;
/// - the stored distances match `x`.
#[derive(Debug, Clone)]
pub struct Rule<T> {
    /// Quadrature points. They are expected to be strictly increasing. This is not
    /// enforced here; [`Rule::validate`] checks it.
    pub x: Vec<T>,
    /// Quadrature weights, one per point.
    pub w: Vec<T>,
    /// Distance from the left endpoint: `x - a`.
    pub x_forward: Vec<T>,
    /// Distance from the right endpoint: `b - x`.
    pub x_backward: Vec<T>,
    /// Left endpoint of the integration interval.
    pub a: T,
    /// Right endpoint of the integration interval.
    pub b: T,
}
36
37impl<T> Rule<T>
38where
39    T: CustomNumeric,
40{
41    /// Create a new quadrature rule from points and weights.
42    ///
43    /// # Arguments
44    /// * `x` - Quadrature points
45    /// * `w` - Quadrature weights
46    /// * `a` - Left endpoint (default: -1.0)
47    /// * `b` - Right endpoint (default: 1.0)
48    ///
49    /// # Panics
50    /// Panics if x and w have different lengths.
51    pub fn new(x: Vec<T>, w: Vec<T>, a: T, b: T) -> Self {
52        assert_eq!(x.len(), w.len(), "x and w must have the same length");
53
54        let x_forward: Vec<T> = x.iter().map(|&xi| xi - a).collect();
55        let x_backward: Vec<T> = x.iter().map(|&xi| b - xi).collect();
56
57        Self {
58            x,
59            w,
60            x_forward,
61            x_backward,
62            a,
63            b,
64        }
65    }
66
67    /// Create a new quadrature rule from vectors.
68    pub fn from_vectors(x: Vec<T>, w: Vec<T>, a: T, b: T) -> Self {
69        Self::new(x, w, a, b)
70    }
71
72    /// Create a default rule with empty arrays.
73    pub fn empty() -> Self {
74        Self {
75            x: vec![],
76            w: vec![],
77            x_forward: vec![],
78            x_backward: vec![],
79            a: <T as CustomNumeric>::from_f64_unchecked(-1.0),
80            b: <T as CustomNumeric>::from_f64_unchecked(1.0),
81        }
82    }
83
84    /// Reseat the rule to a new interval [a, b].
85    ///
86    /// Scales and translates the quadrature points and weights to the new interval.
87    pub fn reseat(&self, a: T, b: T) -> Self {
88        let scaling = (b - a) / (self.b - self.a);
89        let midpoint_old = (self.b + self.a) * <T as CustomNumeric>::from_f64_unchecked(0.5);
90        let midpoint_new = (b + a) * <T as CustomNumeric>::from_f64_unchecked(0.5);
91
92        // Transform x: scaling * (xi - midpoint_old) + midpoint_new
93        let new_x: Vec<T> = self
94            .x
95            .iter()
96            .map(|&xi| scaling * (xi - midpoint_old) + midpoint_new)
97            .collect();
98        let new_w: Vec<T> = self.w.iter().map(|&wi| wi * scaling).collect();
99        let new_x_forward: Vec<T> = self.x_forward.iter().map(|&xi| xi * scaling).collect();
100        let new_x_backward: Vec<T> = self.x_backward.iter().map(|&xi| xi * scaling).collect();
101
102        Self {
103            x: new_x,
104            w: new_w,
105            x_forward: new_x_forward,
106            x_backward: new_x_backward,
107            a,
108            b,
109        }
110    }
111
112    /// Scale the weights by a factor.
113    pub fn scale(&self, factor: T) -> Self {
114        Self {
115            x: self.x.clone(),
116            w: self.w.iter().map(|&wi| wi * factor).collect(),
117            x_forward: self.x_forward.clone(),
118            x_backward: self.x_backward.clone(),
119            a: self.a,
120            b: self.b,
121        }
122    }
123
124    /// Create a piecewise rule over multiple segments.
125    ///
126    /// # Arguments
127    /// * `edges` - Segment boundaries (must be sorted in ascending order)
128    ///
129    /// # Panics
130    /// Panics if edges are not sorted or have less than 2 elements.
131    pub fn piecewise(&self, edges: &[T]) -> Self {
132        if edges.len() < 2 {
133            panic!("edges must have at least 2 elements");
134        }
135
136        // Check if edges are sorted
137        for i in 1..edges.len() {
138            if edges[i] <= edges[i - 1] {
139                panic!("edges must be sorted in ascending order");
140            }
141        }
142
143        let mut rules = Vec::new();
144        for i in 0..edges.len() - 1 {
145            let rule = self.reseat(edges[i], edges[i + 1]);
146            rules.push(rule);
147        }
148
149        Self::join(&rules)
150    }
151
152    /// Join multiple rules into a single rule.
153    ///
154    /// # Arguments
155    /// * `rules` - Vector of rules to join (must be contiguous and sorted)
156    ///
157    /// # Panics
158    /// Panics if rules are empty, not contiguous, or not sorted.
159    pub fn join(rules: &[Self]) -> Self {
160        if rules.is_empty() {
161            return Self::empty();
162        }
163
164        let a = rules[0].a;
165        let b = rules[rules.len() - 1].b;
166
167        // Check that rules are contiguous
168        for i in 1..rules.len() {
169            if (rules[i].a - rules[i - 1].b).abs_as_same_type() > T::epsilon() {
170                panic!("rules must be contiguous");
171            }
172        }
173
174        // Concatenate all arrays
175        let mut x_vec = Vec::new();
176        let mut w_vec = Vec::new();
177        let mut x_forward_vec = Vec::new();
178        let mut x_backward_vec = Vec::new();
179
180        for rule in rules {
181            // Adjust x_forward and x_backward for global coordinates
182            let x_forward_adj: Vec<T> =
183                rule.x_forward.iter().map(|&xi| xi + (rule.a - a)).collect();
184            let x_backward_adj: Vec<T> = rule
185                .x_backward
186                .iter()
187                .map(|&xi| xi + (b - rule.b))
188                .collect();
189
190            x_vec.extend(rule.x.iter().cloned());
191            w_vec.extend(rule.w.iter().cloned());
192            x_forward_vec.extend(x_forward_adj.iter().cloned());
193            x_backward_vec.extend(x_backward_adj.iter().cloned());
194        }
195
196        // Sort by x values to maintain order
197        let mut indices: Vec<usize> = (0..x_vec.len()).collect();
198        indices.sort_by(|&a, &b| x_vec[a].partial_cmp(&x_vec[b]).unwrap());
199
200        let sorted_x: Vec<T> = indices.iter().map(|&i| x_vec[i]).collect();
201        let sorted_w: Vec<T> = indices.iter().map(|&i| w_vec[i]).collect();
202
203        // Recalculate x_forward and x_backward after sorting
204        let sorted_x_forward: Vec<T> = sorted_x.iter().map(|&xi| xi - a).collect();
205        let sorted_x_backward: Vec<T> = sorted_x.iter().map(|&xi| b - xi).collect();
206
207        Self {
208            x: sorted_x,
209            w: sorted_w,
210            x_forward: sorted_x_forward,
211            x_backward: sorted_x_backward,
212            a,
213            b,
214        }
215    }
216
217    /// Convert the rule to a different numeric type.
218    pub fn convert<U>(&self) -> Rule<U>
219    where
220        U: CustomNumeric + Copy + Debug + std::fmt::Display,
221    {
222        let x: Vec<U> = self
223            .x
224            .iter()
225            .map(|&xi| <U as CustomNumeric>::from_f64_unchecked(xi.to_f64()))
226            .collect();
227        let w: Vec<U> = self
228            .w
229            .iter()
230            .map(|&wi| <U as CustomNumeric>::from_f64_unchecked(wi.to_f64()))
231            .collect();
232        let x_forward: Vec<U> = self
233            .x_forward
234            .iter()
235            .map(|&xi| <U as CustomNumeric>::from_f64_unchecked(xi.to_f64()))
236            .collect();
237        let x_backward: Vec<U> = self
238            .x_backward
239            .iter()
240            .map(|&xi| <U as CustomNumeric>::from_f64_unchecked(xi.to_f64()))
241            .collect();
242        let a = <U as CustomNumeric>::from_f64_unchecked(self.a.to_f64());
243        let b = <U as CustomNumeric>::from_f64_unchecked(self.b.to_f64());
244
245        Rule {
246            x,
247            w,
248            x_forward,
249            x_backward,
250            a,
251            b,
252        }
253    }
254
255    /// Validate the rule for consistency.
256    ///
257    /// # Returns
258    /// `true` if the rule is valid, `false` otherwise.
259    pub fn validate(&self) -> bool {
260        // Check interval validity
261        if self.a >= self.b {
262            return false;
263        }
264
265        // Check array lengths
266        if self.x.len() != self.w.len() {
267            return false;
268        }
269
270        if self.x.len() != self.x_forward.len() || self.x.len() != self.x_backward.len() {
271            return false;
272        }
273
274        // Check that all points are within [a, b]
275        for &xi in self.x.iter() {
276            if xi < self.a || xi > self.b {
277                return false;
278            }
279        }
280
281        // Check that points are sorted
282        for i in 1..self.x.len() {
283            if self.x[i] <= self.x[i - 1] {
284                return false;
285            }
286        }
287
288        // Check x_forward and x_backward consistency
289        for i in 0..self.x.len() {
290            let expected_forward = self.x[i] - self.a;
291            let expected_backward = self.b - self.x[i];
292
293            if (self.x_forward[i] - expected_forward).abs_as_same_type() > T::epsilon() {
294                return false;
295            }
296            if (self.x_backward[i] - expected_backward).abs_as_same_type() > T::epsilon() {
297                return false;
298            }
299        }
300
301        true
302    }
303}
304
305/// CustomNumeric-based implementation for f64 and Df64 support
306impl<T> Rule<T>
307where
308    T: CustomNumeric,
309{
310    /// Create a new quadrature rule from points and weights (CustomNumeric version).
311    pub fn new_custom(x: Vec<T>, w: Vec<T>, a: T, b: T) -> Self {
312        assert_eq!(x.len(), w.len(), "x and w must have the same length");
313
314        let x_forward: Vec<T> = x.iter().map(|&xi| xi - a).collect();
315        let x_backward: Vec<T> = x.iter().map(|&xi| b - xi).collect();
316
317        Self {
318            x,
319            w,
320            x_forward,
321            x_backward,
322            a,
323            b,
324        }
325    }
326
327    /// Create a new quadrature rule from vectors (CustomNumeric version).
328    pub fn from_vectors_custom(x: Vec<T>, w: Vec<T>, a: T, b: T) -> Self {
329        Self::new_custom(x, w, a, b)
330    }
331
332    /// Reseat the rule to a new interval [a, b] (CustomNumeric version).
333    pub fn reseat_custom(&self, a: T, b: T) -> Self {
334        let scaling = (b - a) / (self.b - self.a);
335        let midpoint_old = (self.b + self.a) * <T as CustomNumeric>::from_f64_unchecked(0.5);
336        let midpoint_new = (b + a) * <T as CustomNumeric>::from_f64_unchecked(0.5);
337
338        // Transform x: scaling * (xi - midpoint_old) + midpoint_new
339        let new_x: Vec<T> = self
340            .x
341            .iter()
342            .map(|&xi| scaling * (xi - midpoint_old) + midpoint_new)
343            .collect();
344        let new_w: Vec<T> = self.w.iter().map(|&wi| wi * scaling).collect();
345        let new_x_forward: Vec<T> = self.x_forward.iter().map(|&xi| xi * scaling).collect();
346        let new_x_backward: Vec<T> = self.x_backward.iter().map(|&xi| xi * scaling).collect();
347
348        Self {
349            x: new_x,
350            w: new_w,
351            x_forward: new_x_forward,
352            x_backward: new_x_backward,
353            a,
354            b,
355        }
356    }
357
358    /// Scale the weights by a factor (CustomNumeric version).
359    pub fn scale_custom(&self, factor: T) -> Self {
360        Self {
361            x: self.x.clone(),
362            w: self.w.iter().map(|&wi| wi * factor).collect(),
363            x_forward: self.x_forward.clone(),
364            x_backward: self.x_backward.clone(),
365            a: self.a,
366            b: self.b,
367        }
368    }
369
370    /// Validate the rule for consistency (CustomNumeric version).
371    pub fn validate_custom(&self) -> bool {
372        // Check interval validity
373        if self.a >= self.b {
374            return false;
375        }
376
377        // Check array lengths
378        if self.x.len() != self.w.len() {
379            return false;
380        }
381
382        if self.x.len() != self.x_forward.len() || self.x.len() != self.x_backward.len() {
383            return false;
384        }
385
386        // Check that all points are within [a, b]
387        for &xi in self.x.iter() {
388            if xi < self.a || xi > self.b {
389                return false;
390            }
391        }
392
393        // Check that points are sorted
394        for i in 1..self.x.len() {
395            if self.x[i] <= self.x[i - 1] {
396                return false;
397            }
398        }
399
400        // Check x_forward and x_backward consistency
401        for i in 0..self.x.len() {
402            let expected_forward = self.x[i] - self.a;
403            let expected_backward = self.b - self.x[i];
404
405            if (self.x_forward[i] - expected_forward).abs_as_same_type() > T::epsilon() {
406                return false;
407            }
408            if (self.x_backward[i] - expected_backward).abs_as_same_type() > T::epsilon() {
409                return false;
410            }
411        }
412
413        true
414    }
415}
416
417/// Df64-specific implementation without ScalarOperand requirement
418impl Rule<crate::Df64> {
419    /// Create a new quadrature rule from points and weights (Df64 version).
420    pub fn new_twofloat(
421        x: Vec<crate::Df64>,
422        w: Vec<crate::Df64>,
423        a: crate::Df64,
424        b: crate::Df64,
425    ) -> Self {
426        assert_eq!(x.len(), w.len(), "x and w must have the same length");
427
428        let x_forward: Vec<crate::Df64> = x.iter().map(|&xi| xi - a).collect();
429        let x_backward: Vec<crate::Df64> = x.iter().map(|&xi| b - xi).collect();
430
431        Self {
432            x,
433            w,
434            x_forward,
435            x_backward,
436            a,
437            b,
438        }
439    }
440
441    /// Create a new quadrature rule from vectors (Df64 version).
442    pub fn from_vectors_twofloat(
443        x: Vec<crate::Df64>,
444        w: Vec<crate::Df64>,
445        a: crate::Df64,
446        b: crate::Df64,
447    ) -> Self {
448        Self::new_twofloat(x, w, a, b)
449    }
450
451    /// Reseat the rule to a new interval [a, b] (Df64 version).
452    pub fn reseat_twofloat(&self, a: crate::Df64, b: crate::Df64) -> Self {
453        let scaling = (b - a) / (self.b - self.a);
454        let midpoint_old =
455            (self.b + self.a) * <crate::Df64 as CustomNumeric>::from_f64_unchecked(0.5);
456        let midpoint_new = (b + a) * <crate::Df64 as CustomNumeric>::from_f64_unchecked(0.5);
457
458        // Transform x: scaling * (xi - midpoint_old) + midpoint_new
459        let new_x: Vec<crate::Df64> = self
460            .x
461            .iter()
462            .map(|&xi| scaling * (xi - midpoint_old) + midpoint_new)
463            .collect();
464        let new_w: Vec<crate::Df64> = self.w.iter().map(|&wi| wi * scaling).collect();
465        let new_x_forward: Vec<crate::Df64> =
466            self.x_forward.iter().map(|&xi| xi * scaling).collect();
467        let new_x_backward: Vec<crate::Df64> =
468            self.x_backward.iter().map(|&xi| xi * scaling).collect();
469
470        Self {
471            x: new_x,
472            w: new_w,
473            x_forward: new_x_forward,
474            x_backward: new_x_backward,
475            a,
476            b,
477        }
478    }
479
480    /// Scale the weights by a factor (Df64 version).
481    pub fn scale_twofloat(&self, factor: crate::Df64) -> Self {
482        Self {
483            x: self.x.clone(),
484            w: self.w.iter().map(|&wi| wi * factor).collect(),
485            x_forward: self.x_forward.clone(),
486            x_backward: self.x_backward.clone(),
487            a: self.a,
488            b: self.b,
489        }
490    }
491
492    /// Validate the rule for consistency (Df64 version).
493    pub fn validate_twofloat(&self) -> bool {
494        // Check interval validity
495        if self.a >= self.b {
496            return false;
497        }
498
499        // Check array lengths
500        if self.x.len() != self.w.len() {
501            return false;
502        }
503
504        if self.x.len() != self.x_forward.len() || self.x.len() != self.x_backward.len() {
505            return false;
506        }
507
508        // Check that all points are within [a, b]
509        for &xi in self.x.iter() {
510            if xi < self.a || xi > self.b {
511                return false;
512            }
513        }
514
515        // Check that points are sorted
516        for i in 1..self.x.len() {
517            if self.x[i] <= self.x[i - 1] {
518                return false;
519            }
520        }
521
522        // Check x_forward and x_backward consistency
523        for i in 0..self.x.len() {
524            let expected_forward = self.x[i] - self.a;
525            let expected_backward = self.b - self.x[i];
526
527            if (self.x_forward[i] - expected_forward).abs() > crate::Df64::epsilon() {
528                return false;
529            }
530            if (self.x_backward[i] - expected_backward).abs() > crate::Df64::epsilon() {
531                return false;
532            }
533        }
534
535        true
536    }
537}
538
539/// Compute Gauss-Legendre quadrature nodes and weights using Newton's method.
540///
541/// This is a simplified implementation of the Gauss-Legendre quadrature rule.
542/// For production use, a more sophisticated algorithm would be preferred.
543fn gauss_legendre_nodes_weights<T>(n: usize) -> (Vec<T>, Vec<T>)
544where
545    T: CustomNumeric + Copy + Debug + std::fmt::Display + 'static,
546{
547    if n == 0 {
548        return (Vec::new(), Vec::new());
549    }
550
551    if n == 1 {
552        return (
553            vec![<T as CustomNumeric>::from_f64_unchecked(0.0)],
554            vec![<T as CustomNumeric>::from_f64_unchecked(2.0)],
555        );
556    }
557
558    let mut x = Vec::with_capacity(n);
559    let mut w = Vec::with_capacity(n);
560
561    // Use Newton's method to find roots of Legendre polynomial
562    let m = n.div_ceil(2);
563
564    // Use high-precision constants via CustomNumeric trait
565    let pi = T::pi();
566
567    for i in 0..m {
568        // Convert integers directly to avoid f64 intermediate
569        let i_val = <T as CustomNumeric>::from_f64_unchecked(i as f64);
570        let n_val = <T as CustomNumeric>::from_f64_unchecked(n as f64);
571        let three_quarters = <T as CustomNumeric>::from_f64_unchecked(0.75);
572        let half = <T as CustomNumeric>::from_f64_unchecked(0.5);
573
574        // Initial guess using Chebyshev nodes
575        let mut z = (pi * (i_val + three_quarters) / (n_val + half)).cos();
576
577        // Newton's method to refine the root
578        for _ in 0..10 {
579            let (p0, p1) = legendre_polynomial_and_derivative(n, z);
580            if p0.abs_as_same_type() < T::epsilon() {
581                break;
582            }
583            z = z - p0 / p1;
584        }
585
586        // Compute weight using high-precision constants
587        let two = <T as CustomNumeric>::from_f64_unchecked(2.0);
588        let one = <T as CustomNumeric>::from_f64_unchecked(1.0);
589        let (_, p1) = legendre_polynomial_and_derivative(n, z);
590        let weight = two / ((one - z * z) * p1 * p1);
591
592        x.push(-z);
593        w.push(weight);
594
595        if i != n - 1 - i {
596            x.push(z);
597            w.push(weight);
598        }
599    }
600
601    // Sort by x values
602    let mut indices: Vec<usize> = (0..n).collect();
603    indices.sort_by(|&a, &b| x[a].partial_cmp(&x[b]).unwrap());
604
605    let sorted_x: Vec<T> = indices.iter().map(|&i| x[i]).collect();
606    let sorted_w: Vec<T> = indices.iter().map(|&i| w[i]).collect();
607
608    (sorted_x, sorted_w)
609}
610
611/// Compute Legendre polynomial P_n(x) and its derivative using recurrence relation.
612fn legendre_polynomial_and_derivative<T>(n: usize, x: T) -> (T, T)
613where
614    T: CustomNumeric + Copy + Debug + std::fmt::Display + 'static,
615{
616    if n == 0 {
617        return (
618            <T as CustomNumeric>::from_f64_unchecked(1.0),
619            <T as CustomNumeric>::from_f64_unchecked(0.0),
620        );
621    }
622
623    if n == 1 {
624        return (x, <T as CustomNumeric>::from_f64_unchecked(1.0));
625    }
626
627    let mut p0 = <T as CustomNumeric>::from_f64_unchecked(1.0);
628    let mut p1 = x;
629    let mut dp0 = <T as CustomNumeric>::from_f64_unchecked(0.0);
630    let mut dp1 = <T as CustomNumeric>::from_f64_unchecked(1.0);
631
632    for k in 2..=n {
633        let k_f = <T as CustomNumeric>::from_f64_unchecked(k as f64);
634        let k1_f = <T as CustomNumeric>::from_f64_unchecked((k - 1) as f64);
635        let _k2_f = <T as CustomNumeric>::from_f64_unchecked((k - 2) as f64);
636
637        let p2 = ((<T as CustomNumeric>::from_f64_unchecked(2.0) * k1_f
638            + <T as CustomNumeric>::from_f64_unchecked(1.0))
639            * x
640            * p1
641            - k1_f * p0)
642            / k_f;
643        let dp2 = ((<T as CustomNumeric>::from_f64_unchecked(2.0) * k1_f
644            + <T as CustomNumeric>::from_f64_unchecked(1.0))
645            * (p1 + x * dp1)
646            - k1_f * dp0)
647            / k_f;
648
649        p0 = p1;
650        p1 = p2;
651        dp0 = dp1;
652        dp1 = dp2;
653    }
654
655    (p1, dp1)
656}
657
658/// Create a Gauss-Legendre quadrature rule with n points on [-1, 1].
659///
660/// # Arguments
661/// * `n` - Number of quadrature points
662///
663/// # Returns
664/// A Gauss-Legendre quadrature rule
665pub fn legendre<T>(n: usize) -> Rule<T>
666where
667    T: CustomNumeric + Copy + Debug + std::fmt::Display + 'static,
668{
669    if n == 0 {
670        return Rule::empty();
671    }
672
673    let (x, w) = gauss_legendre_nodes_weights(n);
674
675    Rule::from_vectors(
676        x,
677        w,
678        <T as CustomNumeric>::from_f64_unchecked(-1.0),
679        <T as CustomNumeric>::from_f64_unchecked(1.0),
680    )
681}
682
683/// Compute Gauss-Legendre quadrature nodes and weights using CustomNumeric
684fn gauss_legendre_nodes_weights_custom<T>(n: usize) -> (Vec<T>, Vec<T>)
685where
686    T: CustomNumeric,
687{
688    if n == 0 {
689        return (Vec::new(), Vec::new());
690    }
691
692    if n == 1 {
693        return (
694            vec![<T as CustomNumeric>::from_f64_unchecked(0.0)],
695            vec![<T as CustomNumeric>::from_f64_unchecked(2.0)],
696        );
697    }
698
699    let mut x = Vec::with_capacity(n);
700    let mut w = Vec::with_capacity(n);
701
702    // Use Newton's method to find roots of Legendre polynomial
703    let m = n.div_ceil(2);
704    let pi = <T as CustomNumeric>::from_f64_unchecked(std::f64::consts::PI);
705
706    for i in 0..m {
707        // Initial guess using Chebyshev nodes
708        // Note: Df64's cos() has only f64-level precision (~15-16 digits), not the full
709        // theoretical 30-digit precision. This limits Df64 interpolation accuracy to ~1e-16,
710        // not the 1e-30 that might be theoretically possible with perfect double-double arithmetic.
711        let mut z = (pi * <T as CustomNumeric>::from_f64_unchecked(i as f64 + 0.75)
712            / <T as CustomNumeric>::from_f64_unchecked(n as f64 + 0.5))
713        .cos();
714
715        // Newton's method to refine the root
716        for _ in 0..10 {
717            let (p0, p1) = legendre_polynomial_and_derivative_custom(n, z);
718            if p0.abs_as_same_type() < T::epsilon() {
719                break;
720            }
721            z = z - p0 / p1;
722        }
723
724        // Compute weight
725        let (_, p1) = legendre_polynomial_and_derivative_custom(n, z);
726        let weight = <T as CustomNumeric>::from_f64_unchecked(2.0)
727            / ((<T as CustomNumeric>::from_f64_unchecked(1.0) - z * z) * p1 * p1);
728
729        x.push(-z);
730        w.push(weight);
731
732        if i != n - 1 - i {
733            x.push(z);
734            w.push(weight);
735        }
736    }
737
738    // Sort by x values
739    let mut indices: Vec<usize> = (0..n).collect();
740    indices.sort_by(|&a, &b| x[a].partial_cmp(&x[b]).unwrap());
741
742    let sorted_x: Vec<T> = indices.iter().map(|&i| x[i]).collect();
743    let sorted_w: Vec<T> = indices.iter().map(|&i| w[i]).collect();
744
745    (sorted_x, sorted_w)
746}
747
748/// Compute Legendre polynomial P_n(x) and its derivative using CustomNumeric
749fn legendre_polynomial_and_derivative_custom<T>(n: usize, x: T) -> (T, T)
750where
751    T: CustomNumeric,
752{
753    if n == 0 {
754        return (
755            <T as CustomNumeric>::from_f64_unchecked(1.0),
756            <T as CustomNumeric>::from_f64_unchecked(0.0),
757        );
758    }
759
760    if n == 1 {
761        return (x, <T as CustomNumeric>::from_f64_unchecked(1.0));
762    }
763
764    let mut p0 = <T as CustomNumeric>::from_f64_unchecked(1.0);
765    let mut p1 = x;
766    let mut dp0 = <T as CustomNumeric>::from_f64_unchecked(0.0);
767    let mut dp1 = <T as CustomNumeric>::from_f64_unchecked(1.0);
768
769    for k in 2..=n {
770        let k_f = <T as CustomNumeric>::from_f64_unchecked(k as f64);
771        let k1_f = <T as CustomNumeric>::from_f64_unchecked((k - 1) as f64);
772        let _k2_f = <T as CustomNumeric>::from_f64_unchecked((k - 2) as f64);
773
774        let two = <T as CustomNumeric>::from_f64_unchecked(2.0);
775        let one = <T as CustomNumeric>::from_f64_unchecked(1.0);
776
777        let p2 = ((two * k1_f + one) * x * p1 - k1_f * p0) / k_f;
778        let dp2 = ((two * k1_f + one) * (p1 + x * dp1) - k1_f * dp0) / k_f;
779
780        p0 = p1;
781        p1 = p2;
782        dp0 = dp1;
783        dp1 = dp2;
784    }
785
786    (p1, dp1)
787}
788
789/// Create a Gauss-Legendre quadrature rule with n points on [-1, 1] (CustomNumeric version).
790pub fn legendre_custom<T>(n: usize) -> Rule<T>
791where
792    T: CustomNumeric,
793{
794    if n == 0 {
795        return Rule::new_custom(
796            vec![],
797            vec![],
798            <T as CustomNumeric>::from_f64_unchecked(-1.0),
799            <T as CustomNumeric>::from_f64_unchecked(1.0),
800        );
801    }
802
803    let (x, w) = gauss_legendre_nodes_weights_custom(n);
804
805    Rule::from_vectors_custom(
806        x,
807        w,
808        <T as CustomNumeric>::from_f64_unchecked(-1.0),
809        <T as CustomNumeric>::from_f64_unchecked(1.0),
810    )
811}
812
813/// Create a Gauss-Legendre quadrature rule with n points on [-1, 1] (Df64 version).
814pub fn legendre_twofloat(n: usize) -> Rule<crate::Df64> {
815    if n == 0 {
816        return Rule::new_twofloat(
817            vec![],
818            vec![],
819            <crate::Df64 as CustomNumeric>::from_f64_unchecked(-1.0),
820            <crate::Df64 as CustomNumeric>::from_f64_unchecked(1.0),
821        );
822    }
823
824    let mut x: Vec<crate::Df64> = vec![crate::Df64::ZERO; n];
825    let mut w: Vec<crate::Df64> = vec![crate::Df64::ZERO; n];
826    xprec::gauss::gauss_legendre(&mut x, &mut w);
827
828    Rule::from_vectors_twofloat(
829        x,
830        w,
831        <crate::Df64 as CustomNumeric>::from_f64_unchecked(-1.0),
832        <crate::Df64 as CustomNumeric>::from_f64_unchecked(1.0),
833    )
834}
835
836/// Create Legendre Vandermonde matrix for polynomial interpolation
837///
838/// # Arguments
839/// * `x` - Points where polynomials are evaluated
840/// * `degree` - Maximum degree of Legendre polynomials
841///
842/// # Returns
843/// Matrix V where V[i,j] = P_j(x_i), with P_j being the j-th Legendre polynomial
844pub fn legendre_vandermonde<T: CustomNumeric>(x: &[T], degree: usize) -> mdarray::DTensor<T, 2> {
845    use mdarray::DTensor;
846
847    let n = x.len();
848    let mut v = DTensor::<T, 2>::from_elem([n, degree + 1], T::zero());
849
850    // First column is all ones (P_0(x) = 1)
851    for i in 0..n {
852        v[[i, 0]] = T::from_f64_unchecked(1.0);
853    }
854
855    // Second column is x (P_1(x) = x)
856    if degree > 0 {
857        for i in 0..n {
858            v[[i, 1]] = x[i];
859        }
860    }
861
862    // Recurrence relation: P_n(x) = ((2n-1)x*P_{n-1}(x) - (n-1)*P_{n-2}(x)) / n
863    for j in 2..=degree {
864        for i in 0..n {
865            let n_f64 = j as f64;
866            let term1 = T::from_f64_unchecked(2.0 * n_f64 - 1.0) * x[i] * v[[i, j - 1]];
867            let term2 = T::from_f64_unchecked(n_f64 - 1.0) * v[[i, j - 2]];
868            v[[i, j]] = (term1 - term2) / T::from_f64_unchecked(n_f64);
869        }
870    }
871
872    v
873}
874
875/// Generic Legendre Gauss quadrature rule for CustomNumeric types
876pub fn legendre_generic<T: CustomNumeric + 'static>(n: usize) -> Rule<T> {
877    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
878        // For f64, use the existing legendre function
879        let rule_f64 = legendre::<f64>(n);
880        Rule::new(
881            rule_f64
882                .x
883                .iter()
884                .map(|&x| T::from_f64_unchecked(x))
885                .collect(),
886            rule_f64
887                .w
888                .iter()
889                .map(|&w| T::from_f64_unchecked(w))
890                .collect(),
891            T::from_f64_unchecked(rule_f64.a),
892            T::from_f64_unchecked(rule_f64.b),
893        )
894    } else {
895        // For Df64, use legendre_twofloat
896        let rule_tf = legendre_twofloat(n);
897        Rule::new(
898            rule_tf.x.iter().map(|&x| T::convert_from(x)).collect(),
899            rule_tf.w.iter().map(|&w| T::convert_from(w)).collect(),
900            T::from_f64_unchecked(rule_tf.a.into()),
901            T::from_f64_unchecked(rule_tf.b.into()),
902        )
903    }
904}
905
// Unit tests for this module live in the sibling file `gauss_tests.rs` and are
// compiled only for test builds.
#[cfg(test)]
#[path = "gauss_tests.rs"]
mod tests;