// sparse_ir/poly.rs

//! Piecewise Legendre polynomial implementations for SparseIR.
//!
//! This module provides high-performance piecewise Legendre polynomials (single
//! polynomials and vectors of them) with an interface compatible with the C++
//! implementation.

6/// A single piecewise Legendre polynomial
7#[derive(Debug, Clone)]
8pub struct PiecewiseLegendrePoly {
9    /// Polynomial order (degree of Legendre polynomials in each segment)
10    pub polyorder: usize,
11    /// Minimum x value of the domain
12    pub xmin: f64,
13    /// Maximum x value of the domain
14    pub xmax: f64,
15    /// Knot points defining the segments
16    pub knots: Vec<f64>,
17    /// Segment widths (for numerical stability)
18    pub delta_x: Vec<f64>,
19    /// Coefficient matrix: [degree][segment_index]
20    pub data: mdarray::DTensor<f64, 2>,
21    /// Symmetry parameter
22    pub symm: i32,
23    /// Polynomial parameter (used in power moments calculation)
24    pub l: i32,
25    /// Segment midpoints
26    pub xm: Vec<f64>,
27    /// Inverse segment widths
28    pub inv_xs: Vec<f64>,
29    /// Normalization factors
30    pub norms: Vec<f64>,
31}
32
33impl PiecewiseLegendrePoly {
34    /// Create a new PiecewiseLegendrePoly from data and knots
35    pub fn new(
36        data: mdarray::DTensor<f64, 2>,
37        knots: Vec<f64>,
38        l: i32,
39        delta_x: Option<Vec<f64>>,
40        symm: i32,
41    ) -> Self {
42        let polyorder = data.shape().0;
43        let nsegments = data.shape().1;
44
45        if knots.len() != nsegments + 1 {
46            panic!(
47                "Invalid knots array: expected {} knots, got {}",
48                nsegments + 1,
49                knots.len()
50            );
51        }
52
53        // Validate knots are sorted
54        for i in 1..knots.len() {
55            if knots[i] <= knots[i - 1] {
56                panic!("Knots must be monotonically increasing");
57            }
58        }
59
60        // Compute delta_x if not provided
61        let delta_x =
62            delta_x.unwrap_or_else(|| (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect());
63
64        // Validate delta_x matches knots
65        for i in 0..delta_x.len() {
66            let expected = knots[i + 1] - knots[i];
67            if (delta_x[i] - expected).abs() > 1e-10 {
68                panic!("delta_x must match knots");
69            }
70        }
71
72        // Compute segment midpoints
73        let xm: Vec<f64> = (0..nsegments)
74            .map(|i| 0.5 * (knots[i] + knots[i + 1]))
75            .collect();
76
77        // Compute inverse segment widths
78        let inv_xs: Vec<f64> = delta_x.iter().map(|&dx| 2.0 / dx).collect();
79
80        // Compute normalization factors
81        let norms: Vec<f64> = inv_xs.iter().map(|&inv_x| inv_x.sqrt()).collect();
82
83        Self {
84            polyorder,
85            xmin: knots[0],
86            xmax: knots[knots.len() - 1],
87            knots,
88            delta_x,
89            data,
90            symm,
91            l,
92            xm,
93            inv_xs,
94            norms,
95        }
96    }
97
98    /// Create a new PiecewiseLegendrePoly with new data but same structure
99    pub fn with_data(&self, new_data: mdarray::DTensor<f64, 2>) -> Self {
100        Self {
101            data: new_data,
102            ..self.clone()
103        }
104    }
105
106    /// Get the symmetry parameter
107    pub fn symm(&self) -> i32 {
108        self.symm
109    }
110
111    /// Create a new PiecewiseLegendrePoly with new data and symmetry
112    pub fn with_data_and_symmetry(
113        &self,
114        new_data: mdarray::DTensor<f64, 2>,
115        new_symm: i32,
116    ) -> Self {
117        Self {
118            data: new_data,
119            symm: new_symm,
120            ..self.clone()
121        }
122    }
123
124    /// Rescale domain: create a new polynomial with the same data but different knots
125    ///
126    /// This is useful for transforming from one domain to another, e.g.,
127    /// from x ∈ [-1, 1] to τ ∈ [0, β].
128    ///
129    /// # Arguments
130    ///
131    /// * `new_knots` - New knot points
132    /// * `new_delta_x` - Optional new segment widths (computed from knots if None)
133    /// * `new_symm` - Optional new symmetry parameter (keeps old if None)
134    ///
135    /// # Returns
136    ///
137    /// New polynomial with rescaled domain
138    pub fn rescale_domain(
139        &self,
140        new_knots: Vec<f64>,
141        new_delta_x: Option<Vec<f64>>,
142        new_symm: Option<i32>,
143    ) -> Self {
144        Self::new(
145            self.data.clone(),
146            new_knots,
147            self.l,
148            new_delta_x,
149            new_symm.unwrap_or(self.symm),
150        )
151    }
152
153    /// Scale all data values by a constant factor
154    ///
155    /// This is useful for normalizations, e.g., multiplying by √β for
156    /// Fourier transform preparations.
157    ///
158    /// # Arguments
159    ///
160    /// * `factor` - Scaling factor to multiply all data by
161    ///
162    /// # Returns
163    ///
164    /// New polynomial with scaled data
165    pub fn scale_data(&self, factor: f64) -> Self {
166        Self::with_data(
167            self,
168            mdarray::DTensor::<f64, 2>::from_fn(*self.data.shape(), |idx| self.data[idx] * factor),
169        )
170    }
171
172    /// Evaluate the polynomial at a given point
173    pub fn evaluate(&self, x: f64) -> f64 {
174        let (i, x_tilde) = self.split(x);
175        // Extract column i into a Vec
176        let coeffs: Vec<f64> = (0..self.data.shape().0)
177            .map(|row| self.data[[row, i]])
178            .collect();
179        let value = self.evaluate_legendre_polynomial(x_tilde, &coeffs);
180        value * self.norms[i]
181    }
182
183    /// Evaluate the polynomial at multiple points
184    pub fn evaluate_many(&self, xs: &[f64]) -> Vec<f64> {
185        xs.iter().map(|&x| self.evaluate(x)).collect()
186    }
187
188    /// Split x into segment index and normalized x
189    pub fn split(&self, x: f64) -> (usize, f64) {
190        if x < self.xmin || x > self.xmax {
191            panic!("x = {} is outside domain [{}, {}]", x, self.xmin, self.xmax);
192        }
193
194        // Find the segment containing x
195        for i in 0..self.knots.len() - 1 {
196            if x >= self.knots[i] && x <= self.knots[i + 1] {
197                // Transform x to [-1, 1] for Legendre polynomials
198                let x_tilde = 2.0 * (x - self.xm[i]) / self.delta_x[i];
199                return (i, x_tilde);
200            }
201        }
202
203        // Handle edge case: x exactly at the last knot
204        let last_idx = self.knots.len() - 2;
205        let x_tilde = 2.0 * (x - self.xm[last_idx]) / self.delta_x[last_idx];
206        (last_idx, x_tilde)
207    }
208
209    /// Evaluate Legendre polynomial using recurrence relation
210    pub fn evaluate_legendre_polynomial(&self, x: f64, coeffs: &[f64]) -> f64 {
211        if coeffs.is_empty() {
212            return 0.0;
213        }
214
215        let mut result = 0.0;
216        let mut p_prev = 1.0; // P_0(x) = 1
217        let mut p_curr = x; // P_1(x) = x
218
219        // Add first two terms
220        if !coeffs.is_empty() {
221            result += coeffs[0] * p_prev;
222        }
223        if coeffs.len() > 1 {
224            result += coeffs[1] * p_curr;
225        }
226
227        // Use recurrence relation: P_{n+1}(x) = ((2n+1)x*P_n(x) - n*P_{n-1}(x))/(n+1)
228        for n in 1..coeffs.len() - 1 {
229            let p_next =
230                ((2.0 * (n as f64) + 1.0) * x * p_curr - (n as f64) * p_prev) / ((n + 1) as f64);
231            result += coeffs[n + 1] * p_next;
232            p_prev = p_curr;
233            p_curr = p_next;
234        }
235
236        result
237    }
238
239    /// Compute derivative of the polynomial
240    pub fn deriv(&self, n: usize) -> Self {
241        if n == 0 {
242            return self.clone();
243        }
244
245        // Compute derivative coefficients
246        let mut ddata = self.data.clone();
247        for _ in 0..n {
248            ddata = self.compute_derivative_coefficients(&ddata);
249        }
250
251        // Apply scaling factors (C++: ddata.col(i) *= std::pow(inv_xs[i], n))
252        let ddata_shape = *ddata.shape();
253        for i in 0..ddata_shape.1 {
254            let inv_x_power = self.inv_xs[i].powi(n as i32);
255            for j in 0..ddata_shape.0 {
256                ddata[[j, i]] *= inv_x_power;
257            }
258        }
259
260        // Update symmetry: C++: int new_symm = std::pow(-1, n) * symm;
261        let new_symm = if n % 2 == 0 { self.symm } else { -self.symm };
262
263        Self {
264            data: ddata,
265            symm: new_symm,
266            ..self.clone()
267        }
268    }
269
270    /// Compute derivative coefficients using the same algorithm as C++ legder function
271    fn compute_derivative_coefficients(
272        &self,
273        coeffs: &mdarray::DTensor<f64, 2>,
274    ) -> mdarray::DTensor<f64, 2> {
275        let mut c = coeffs.clone();
276        let c_shape = *c.shape();
277        let mut n = c_shape.0;
278
279        // Single derivative step (equivalent to C++ legder with cnt=1)
280        if n <= 1 {
281            return mdarray::DTensor::<f64, 2>::from_elem([1, c.shape().1], 0.0);
282        }
283
284        n -= 1;
285        let mut der = mdarray::DTensor::<f64, 2>::from_elem([n, c.shape().1], 0.0);
286
287        // C++ implementation: for (int j = n; j >= 2; --j)
288        for j in (2..=n).rev() {
289            // C++: der.row(j - 1) = (2 * j - 1) * c.row(j);
290            for col in 0..c_shape.1 {
291                der[[j - 1, col]] = (2.0 * (j as f64) - 1.0) * c[[j, col]];
292            }
293            // C++: c.row(j - 2) += c.row(j);
294            for col in 0..c_shape.1 {
295                c[[j - 2, col]] += c[[j, col]];
296            }
297        }
298
299        // C++: if (n > 1) der.row(1) = 3 * c.row(2);
300        if n > 1 {
301            for col in 0..c_shape.1 {
302                der[[1, col]] = 3.0 * c[[2, col]];
303            }
304        }
305
306        // C++: der.row(0) = c.row(1);
307        for col in 0..c_shape.1 {
308            der[[0, col]] = c[[1, col]];
309        }
310
311        der
312    }
313
314    /// Compute derivatives at a point x
315    pub fn derivs(&self, x: f64) -> Vec<f64> {
316        let mut results = Vec::new();
317
318        // Compute up to polyorder derivatives
319        for n in 0..self.polyorder {
320            let deriv_poly = self.deriv(n);
321            results.push(deriv_poly.evaluate(x));
322        }
323
324        results
325    }
326
327    /// Compute overlap integral with a function
328    pub fn overlap<F>(&self, f: F) -> f64
329    where
330        F: Fn(f64) -> f64,
331    {
332        let mut integral = 0.0;
333
334        for i in 0..self.knots.len() - 1 {
335            let segment_integral =
336                self.gauss_legendre_quadrature(self.knots[i], self.knots[i + 1], |x| {
337                    self.evaluate(x) * f(x)
338                });
339            integral += segment_integral;
340        }
341
342        integral
343    }
344
345    /// Gauss-Legendre quadrature over [a, b]
346    fn gauss_legendre_quadrature<F>(&self, a: f64, b: f64, f: F) -> f64
347    where
348        F: Fn(f64) -> f64,
349    {
350        // 5-point Gauss-Legendre quadrature
351        const XG: [f64; 5] = [
352            -0.906179845938664,
353            -0.538469310105683,
354            0.0,
355            0.538469310105683,
356            0.906179845938664,
357        ];
358        const WG: [f64; 5] = [
359            0.236926885056189,
360            0.478628670499366,
361            0.568888888888889,
362            0.478628670499366,
363            0.236926885056189,
364        ];
365
366        let c1 = (b - a) / 2.0;
367        let c2 = (b + a) / 2.0;
368
369        let mut integral = 0.0;
370        for j in 0..5 {
371            let x = c1 * XG[j] + c2;
372            integral += WG[j] * f(x);
373        }
374
375        integral * c1
376    }
377
378    /// Find roots of the polynomial using C++ compatible algorithm
379    pub fn roots(&self) -> Vec<f64> {
380        // Refine the grid by factor of 4 for better root finding
381        // (C++ uses 2, but RegularizedBoseKernel needs finer resolution)
382        let refined_grid = self.refine_grid(&self.knots, 4);
383
384        // Find all roots using the refined grid
385        self.find_all_roots(&refined_grid)
386    }
387
388    /// Refine grid by factor alpha (C++ compatible)
389    fn refine_grid(&self, grid: &[f64], alpha: usize) -> Vec<f64> {
390        let mut refined = Vec::new();
391
392        for i in 0..grid.len() - 1 {
393            let start = grid[i];
394            let step = (grid[i + 1] - grid[i]) / (alpha as f64);
395            for j in 0..alpha {
396                refined.push(start + (j as f64) * step);
397            }
398        }
399        refined.push(grid[grid.len() - 1]);
400        refined
401    }
402
403    /// Find all roots using refined grid (C++ compatible)
404    fn find_all_roots(&self, xgrid: &[f64]) -> Vec<f64> {
405        if xgrid.is_empty() {
406            return Vec::new();
407        }
408
409        // Evaluate function at all grid points
410        let fx: Vec<f64> = xgrid.iter().map(|&x| self.evaluate(x)).collect();
411
412        // Find exact zeros (direct hits)
413        let mut x_hit = Vec::new();
414        for i in 0..fx.len() {
415            if fx[i] == 0.0 {
416                x_hit.push(xgrid[i]);
417            }
418        }
419
420        // Find sign changes
421        let mut sign_change = Vec::new();
422        for i in 0..fx.len() - 1 {
423            let has_sign_change = fx[i].signum() != fx[i + 1].signum();
424            let not_hit = fx[i] != 0.0 && fx[i + 1] != 0.0;
425            sign_change.push(has_sign_change && not_hit);
426        }
427
428        // If no sign changes, return only direct hits
429        if sign_change.iter().all(|&sc| !sc) {
430            x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
431            return x_hit;
432        }
433
434        // Find intervals with sign changes
435        let mut a_intervals = Vec::new();
436        let mut b_intervals = Vec::new();
437        let mut fa_values = Vec::new();
438
439        for i in 0..sign_change.len() {
440            if sign_change[i] {
441                a_intervals.push(xgrid[i]);
442                b_intervals.push(xgrid[i + 1]);
443                fa_values.push(fx[i]);
444            }
445        }
446
447        // Calculate epsilon for convergence
448        let max_elm = xgrid.iter().map(|&x| x.abs()).fold(0.0, f64::max);
449        let epsilon_x = f64::EPSILON * max_elm;
450
451        // Use bisection for each interval with sign change
452        for i in 0..a_intervals.len() {
453            let root = self.bisect(a_intervals[i], b_intervals[i], fa_values[i], epsilon_x);
454            x_hit.push(root);
455        }
456
457        // Sort and return
458        x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
459        x_hit
460    }
461
462    /// Bisection method to find root (C++ compatible)
463    fn bisect(&self, a: f64, b: f64, fa: f64, eps: f64) -> f64 {
464        let mut a = a;
465        let mut b = b;
466        let mut fa = fa;
467
468        loop {
469            let mid = (a + b) / 2.0;
470            if self.close_enough(a, mid, eps) {
471                return mid;
472            }
473
474            let fmid = self.evaluate(mid);
475            if fa.signum() != fmid.signum() {
476                b = mid;
477            } else {
478                a = mid;
479                fa = fmid;
480            }
481        }
482    }
483
484    /// Check if two values are close enough (C++ compatible)
485    fn close_enough(&self, a: f64, b: f64, eps: f64) -> bool {
486        (a - b).abs() <= eps
487    }
488
489    // Accessor methods to match C++ interface
490    pub fn get_xmin(&self) -> f64 {
491        self.xmin
492    }
493    pub fn get_xmax(&self) -> f64 {
494        self.xmax
495    }
496    pub fn get_l(&self) -> i32 {
497        self.l
498    }
499    pub fn get_domain(&self) -> (f64, f64) {
500        (self.xmin, self.xmax)
501    }
502    pub fn get_knots(&self) -> &[f64] {
503        &self.knots
504    }
505    pub fn get_delta_x(&self) -> &[f64] {
506        &self.delta_x
507    }
508    pub fn get_symm(&self) -> i32 {
509        self.symm
510    }
511    pub fn get_data(&self) -> &mdarray::DTensor<f64, 2> {
512        &self.data
513    }
514    pub fn get_norms(&self) -> &[f64] {
515        &self.norms
516    }
517    pub fn get_polyorder(&self) -> usize {
518        self.polyorder
519    }
520}
521
522/// Vector of piecewise Legendre polynomials
523#[derive(Debug, Clone)]
524pub struct PiecewiseLegendrePolyVector {
525    /// Individual polynomials
526    pub polyvec: Vec<PiecewiseLegendrePoly>,
527}
528
529impl PiecewiseLegendrePolyVector {
530    /// Constructor with a vector of PiecewiseLegendrePoly
531    ///
532    /// # Panics
533    /// Panics if the input vector is empty, as empty PiecewiseLegendrePolyVector is not meaningful
534    pub fn new(polyvec: Vec<PiecewiseLegendrePoly>) -> Self {
535        if polyvec.is_empty() {
536            panic!("Cannot create empty PiecewiseLegendrePolyVector");
537        }
538        Self { polyvec }
539    }
540
541    /// Get the polynomials
542    pub fn get_polys(&self) -> &[PiecewiseLegendrePoly] {
543        &self.polyvec
544    }
545
546    /// Constructor with a 3D array, knots, and symmetry vector
547    pub fn from_3d_data(
548        data3d: mdarray::DTensor<f64, 3>,
549        knots: Vec<f64>,
550        symm: Option<Vec<i32>>,
551    ) -> Self {
552        let npolys = data3d.shape().2;
553        let mut polyvec = Vec::with_capacity(npolys);
554
555        if let Some(ref symm_vec) = symm {
556            if symm_vec.len() != npolys {
557                panic!("Sizes of data and symm don't match");
558            }
559        }
560
561        // Compute delta_x from knots
562        let delta_x: Vec<f64> = (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect();
563
564        for i in 0..npolys {
565            // Extract 2D data for this polynomial
566            let data3d_shape = data3d.shape();
567            let mut data =
568                mdarray::DTensor::<f64, 2>::from_elem([data3d_shape.0, data3d_shape.1], 0.0);
569            for j in 0..data3d_shape.0 {
570                for k in 0..data3d_shape.1 {
571                    data[[j, k]] = data3d[[j, k, i]];
572                }
573            }
574
575            let poly = PiecewiseLegendrePoly::new(
576                data,
577                knots.clone(),
578                i as i32,
579                Some(delta_x.clone()),
580                symm.as_ref().map_or(0, |s| s[i]),
581            );
582
583            polyvec.push(poly);
584        }
585
586        Self { polyvec }
587    }
588
589    /// Get the size of the vector
590    pub fn size(&self) -> usize {
591        self.polyvec.len()
592    }
593
594    /// Rescale domain for all polynomials in the vector
595    ///
596    /// Creates a new PiecewiseLegendrePolyVector where each polynomial has
597    /// the same data but new knots and delta_x.
598    ///
599    /// # Arguments
600    ///
601    /// * `new_knots` - New knot points (same for all polynomials)
602    /// * `new_delta_x` - Optional new segment widths
603    /// * `new_symm` - Optional vector of new symmetry parameters (one per polynomial)
604    ///
605    /// # Returns
606    ///
607    /// New vector with rescaled domains
608    pub fn rescale_domain(
609        &self,
610        new_knots: Vec<f64>,
611        new_delta_x: Option<Vec<f64>>,
612        new_symm: Option<Vec<i32>>,
613    ) -> Self {
614        let polyvec = self
615            .polyvec
616            .iter()
617            .enumerate()
618            .map(|(i, poly)| {
619                let symm = new_symm.as_ref().map(|s| s[i]);
620                poly.rescale_domain(new_knots.clone(), new_delta_x.clone(), symm)
621            })
622            .collect();
623
624        Self { polyvec }
625    }
626
627    /// Scale all data values by a constant factor
628    ///
629    /// Multiplies the data of all polynomials by the same factor.
630    ///
631    /// # Arguments
632    ///
633    /// * `factor` - Scaling factor to multiply all data by
634    ///
635    /// # Returns
636    ///
637    /// New vector with scaled data
638    pub fn scale_data(&self, factor: f64) -> Self {
639        let polyvec = self
640            .polyvec
641            .iter()
642            .map(|poly| poly.scale_data(factor))
643            .collect();
644
645        Self { polyvec }
646    }
647
648    /// Get polynomial by index (immutable)
649    pub fn get(&self, index: usize) -> Option<&PiecewiseLegendrePoly> {
650        self.polyvec.get(index)
651    }
652
653    /// Get polynomial by index (mutable) - deprecated, use immutable design instead
654    #[deprecated(
655        note = "PiecewiseLegendrePolyVector is designed to be immutable. Use get() and create new instances for modifications."
656    )]
657    pub fn get_mut(&mut self, index: usize) -> Option<&mut PiecewiseLegendrePoly> {
658        self.polyvec.get_mut(index)
659    }
660
661    /// Extract a single polynomial as a vector
662    pub fn slice_single(&self, index: usize) -> Option<Self> {
663        self.polyvec.get(index).map(|poly| Self {
664            polyvec: vec![poly.clone()],
665        })
666    }
667
668    /// Extract multiple polynomials by indices
669    pub fn slice_multi(&self, indices: &[usize]) -> Self {
670        // Validate indices
671        for &idx in indices {
672            if idx >= self.polyvec.len() {
673                panic!("Index {} out of range", idx);
674            }
675        }
676
677        // Check for duplicates
678        {
679            let mut unique_indices = indices.to_vec();
680            unique_indices.sort();
681            unique_indices.dedup();
682            if unique_indices.len() != indices.len() {
683                panic!("Duplicate indices not allowed");
684            }
685        }
686
687        let new_polyvec: Vec<_> = indices
688            .iter()
689            .map(|&idx| self.polyvec[idx].clone())
690            .collect();
691
692        Self {
693            polyvec: new_polyvec,
694        }
695    }
696
697    /// Evaluate all polynomials at a single point
698    pub fn evaluate_at(&self, x: f64) -> Vec<f64> {
699        self.polyvec.iter().map(|poly| poly.evaluate(x)).collect()
700    }
701
702    /// Evaluate all polynomials at multiple points
703    pub fn evaluate_at_many(&self, xs: &[f64]) -> mdarray::DTensor<f64, 2> {
704        let n_funcs = self.polyvec.len();
705        let n_points = xs.len();
706        let mut results = mdarray::DTensor::<f64, 2>::from_elem([n_funcs, n_points], 0.0);
707
708        for (i, poly) in self.polyvec.iter().enumerate() {
709            for (j, &x) in xs.iter().enumerate() {
710                results[[i, j]] = poly.evaluate(x);
711            }
712        }
713
714        results
715    }
716
717    // Accessor methods to match C++ interface
718    pub fn xmin(&self) -> f64 {
719        if self.polyvec.is_empty() {
720            panic!("Cannot get xmin from empty PiecewiseLegendrePolyVector");
721        }
722        self.polyvec[0].xmin
723    }
724
725    pub fn xmax(&self) -> f64 {
726        if self.polyvec.is_empty() {
727            panic!("Cannot get xmax from empty PiecewiseLegendrePolyVector");
728        }
729        self.polyvec[0].xmax
730    }
731
732    pub fn get_knots(&self, tolerance: Option<f64>) -> Vec<f64> {
733        if self.polyvec.is_empty() {
734            panic!("Cannot get knots from empty PiecewiseLegendrePolyVector");
735        }
736        const DEFAULT_TOLERANCE: f64 = 1e-10;
737        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
738
739        // Collect all knots from all polynomials
740        let mut all_knots = Vec::new();
741        for poly in &self.polyvec {
742            for &knot in &poly.knots {
743                all_knots.push(knot);
744            }
745        }
746
747        // Sort and remove duplicates
748        {
749            all_knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
750            all_knots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
751        }
752        all_knots
753    }
754
755    pub fn get_delta_x(&self) -> Vec<f64> {
756        if self.polyvec.is_empty() {
757            panic!("Cannot get delta_x from empty PiecewiseLegendrePolyVector");
758        }
759        self.polyvec[0].delta_x.clone()
760    }
761
762    pub fn get_polyorder(&self) -> usize {
763        if self.polyvec.is_empty() {
764            panic!("Cannot get polyorder from empty PiecewiseLegendrePolyVector");
765        }
766        self.polyvec[0].polyorder
767    }
768
769    pub fn get_norms(&self) -> &[f64] {
770        if self.polyvec.is_empty() {
771            panic!("Cannot get norms from empty PiecewiseLegendrePolyVector");
772        }
773        &self.polyvec[0].norms
774    }
775
776    pub fn get_symm(&self) -> Vec<i32> {
777        if self.polyvec.is_empty() {
778            panic!("Cannot get symm from empty PiecewiseLegendrePolyVector");
779        }
780        self.polyvec.iter().map(|poly| poly.symm).collect()
781    }
782
783    /// Get data as 3D tensor: [segment][degree][polynomial]
784    pub fn get_data(&self) -> mdarray::DTensor<f64, 3> {
785        if self.polyvec.is_empty() {
786            panic!("Cannot get data from empty PiecewiseLegendrePolyVector");
787        }
788
789        let nsegments = self.polyvec[0].data.shape().1;
790        let polyorder = self.polyvec[0].polyorder;
791        let npolys = self.polyvec.len();
792
793        let mut data = mdarray::DTensor::<f64, 3>::from_elem([nsegments, polyorder, npolys], 0.0);
794
795        for (poly_idx, poly) in self.polyvec.iter().enumerate() {
796            for segment in 0..nsegments {
797                for degree in 0..polyorder {
798                    data[[segment, degree, poly_idx]] = poly.data[[degree, segment]];
799                }
800            }
801        }
802
803        data
804    }
805
806    /// Find roots of all polynomials
807    pub fn roots(&self, tolerance: Option<f64>) -> Vec<f64> {
808        if self.polyvec.is_empty() {
809            panic!("Cannot get roots from empty PiecewiseLegendrePolyVector");
810        }
811        const DEFAULT_TOLERANCE: f64 = 1e-10;
812        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
813        let mut all_roots = Vec::new();
814
815        for poly in &self.polyvec {
816            let poly_roots = poly.roots();
817            for root in poly_roots {
818                all_roots.push(root);
819            }
820        }
821
822        // Sort in descending order and remove duplicates (like C++ implementation)
823        {
824            all_roots.sort_by(|a, b| b.partial_cmp(a).unwrap());
825            all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
826        }
827        all_roots
828    }
829
830    /// Get reference to last polynomial
831    ///
832    /// C++ equivalent: u.polyvec.back()
833    pub fn last(&self) -> &PiecewiseLegendrePoly {
834        self.polyvec
835            .last()
836            .expect("Cannot get last from empty PiecewiseLegendrePolyVector")
837    }
838
839    /// Get the number of roots
840    pub fn nroots(&self, tolerance: Option<f64>) -> usize {
841        if self.polyvec.is_empty() {
842            panic!("Cannot get nroots from empty PiecewiseLegendrePolyVector");
843        }
844        self.roots(tolerance).len()
845    }
846}
847
848impl std::ops::Index<usize> for PiecewiseLegendrePolyVector {
849    type Output = PiecewiseLegendrePoly;
850
851    fn index(&self, index: usize) -> &Self::Output {
852        &self.polyvec[index]
853    }
854}
855
856/// Get default sampling points in [-1, 1]
857///
858/// C++ implementation: libsparseir/include/sparseir/basis.hpp:287-310
859///
860/// For orthogonal polynomials (the high-T limit of IR), we know that the
861/// ideal sampling points for a basis of size L are the roots of the L-th
862/// polynomial. We empirically find that these stay good sampling points
863/// for our kernels (probably because the kernels are totally positive).
864///
865/// If we do not have enough polynomials in the basis, we approximate the
866/// roots of the L'th polynomial by the extrema of the last basis function,
867/// which is sensible due to the strong interleaving property of these
868/// functions' roots.
869pub fn default_sampling_points(u: &PiecewiseLegendrePolyVector, l: usize) -> Vec<f64> {
870    // C++: if (u.xmin() != -1.0 || u.xmax() != 1.0)
871    //          throw std::runtime_error("Expecting unscaled functions here.");
872    if (u.xmin() - (-1.0)).abs() > 1e-10 || (u.xmax() - 1.0).abs() > 1e-10 {
873        panic!("Expecting unscaled functions here.");
874    }
875
876    let x0 = if l < u.polyvec.len() {
877        // C++: return u.polyvec[L].roots();
878        u[l].roots()
879    } else {
880        // C++: PiecewiseLegendrePoly poly = u.polyvec.back();
881        //      Eigen::VectorXd maxima = poly.deriv().roots();
882        let poly = u.last();
883        let poly_deriv = poly.deriv(1);
884        let maxima = poly_deriv.roots();
885
886        // C++: double left = (maxima[0] + poly.xmin) / 2.0;
887        let left = (maxima[0] + poly.xmin) / 2.0;
888
889        // C++: double right = (maxima[maxima.size() - 1] + poly.xmax) / 2.0;
890        let right = (maxima[maxima.len() - 1] + poly.xmax) / 2.0;
891
892        // C++: Eigen::VectorXd x0(maxima.size() + 2);
893        //      x0[0] = left;
894        //      x0.segment(1, maxima.size()) = maxima;
895        //      x0[x0.size() - 1] = right;
896        let mut x0_vec = Vec::with_capacity(maxima.len() + 2);
897        x0_vec.push(left);
898        x0_vec.extend_from_slice(&maxima);
899        x0_vec.push(right);
900
901        x0_vec
902    };
903
904    // C++: if (x0.size() != L) { warning }
905    if x0.len() != l {
906        eprintln!(
907            "Warning: Expecting to get {} sampling points for corresponding basis function, \
908             instead got {}. This may happen if not enough precision is left in the polynomial.",
909            l,
910            x0.len()
911        );
912    }
913
914    x0
915}
916
// `IndexMut` is intentionally not implemented: `PiecewiseLegendrePolyVector` is
// designed to be immutable, so create a new instance when modification is needed.

// A call-operator (`FnOnce`/`Fn`) implementation is omitted because implementing
// the `Fn*` traits requires unstable compiler features; use `evaluate_at()` and
// `evaluate_at_many()` directly.

// Unit tests for this module live in the sibling file `poly_tests.rs`.
#[path = "poly_tests.rs"]
#[cfg(test)]
mod poly_tests;