// sparse_ir/poly.rs

//! Piecewise Legendre polynomial implementations for SparseIR.
//!
//! This module provides high-performance piecewise Legendre polynomial
//! functionality compatible with the C++ implementation.

6/// A single piecewise Legendre polynomial
7#[derive(Debug, Clone)]
8pub struct PiecewiseLegendrePoly {
9    /// Polynomial order (degree of Legendre polynomials in each segment)
10    pub polyorder: usize,
11    /// Minimum x value of the domain
12    pub xmin: f64,
13    /// Maximum x value of the domain
14    pub xmax: f64,
15    /// Knot points defining the segments
16    pub knots: Vec<f64>,
17    /// Segment widths (for numerical stability)
18    pub delta_x: Vec<f64>,
19    /// Coefficient matrix: [degree][segment_index]
20    pub data: mdarray::DTensor<f64, 2>,
21    /// Symmetry parameter
22    pub symm: i32,
23    /// Polynomial parameter (used in power moments calculation)
24    pub l: i32,
25    /// Segment midpoints
26    pub xm: Vec<f64>,
27    /// Inverse segment widths
28    pub inv_xs: Vec<f64>,
29    /// Normalization factors
30    pub norms: Vec<f64>,
31}
32
33impl PiecewiseLegendrePoly {
34    /// Create a new PiecewiseLegendrePoly from data and knots
35    pub fn new(
36        data: mdarray::DTensor<f64, 2>,
37        knots: Vec<f64>,
38        l: i32,
39        delta_x: Option<Vec<f64>>,
40        symm: i32,
41    ) -> Self {
42        let polyorder = data.shape().0;
43        let nsegments = data.shape().1;
44
45        if knots.len() != nsegments + 1 {
46            panic!(
47                "Invalid knots array: expected {} knots, got {}",
48                nsegments + 1,
49                knots.len()
50            );
51        }
52
53        // Validate knots are sorted
54        for i in 1..knots.len() {
55            if knots[i] <= knots[i - 1] {
56                panic!("Knots must be monotonically increasing");
57            }
58        }
59
60        // Compute delta_x if not provided
61        let delta_x =
62            delta_x.unwrap_or_else(|| (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect());
63
64        // Validate delta_x matches knots
65        for i in 0..delta_x.len() {
66            let expected = knots[i + 1] - knots[i];
67            if (delta_x[i] - expected).abs() > 1e-10 {
68                panic!("delta_x must match knots");
69            }
70        }
71
72        // Compute segment midpoints
73        let xm: Vec<f64> = (0..nsegments)
74            .map(|i| 0.5 * (knots[i] + knots[i + 1]))
75            .collect();
76
77        // Compute inverse segment widths
78        let inv_xs: Vec<f64> = delta_x.iter().map(|&dx| 2.0 / dx).collect();
79
80        // Compute normalization factors
81        let norms: Vec<f64> = inv_xs.iter().map(|&inv_x| inv_x.sqrt()).collect();
82
83        Self {
84            polyorder,
85            xmin: knots[0],
86            xmax: knots[knots.len() - 1],
87            knots,
88            delta_x,
89            data,
90            symm,
91            l,
92            xm,
93            inv_xs,
94            norms,
95        }
96    }
97
98    /// Create a new PiecewiseLegendrePoly with new data but same structure
99    pub fn with_data(&self, new_data: mdarray::DTensor<f64, 2>) -> Self {
100        Self {
101            data: new_data,
102            ..self.clone()
103        }
104    }
105
106    /// Get the symmetry parameter
107    pub fn symm(&self) -> i32 {
108        self.symm
109    }
110
111    /// Create a new PiecewiseLegendrePoly with new data and symmetry
112    pub fn with_data_and_symmetry(
113        &self,
114        new_data: mdarray::DTensor<f64, 2>,
115        new_symm: i32,
116    ) -> Self {
117        Self {
118            data: new_data,
119            symm: new_symm,
120            ..self.clone()
121        }
122    }
123
124    /// Rescale domain: create a new polynomial with the same data but different knots
125    ///
126    /// This is useful for transforming from one domain to another, e.g.,
127    /// from x ∈ [-1, 1] to τ ∈ [0, β].
128    ///
129    /// # Arguments
130    ///
131    /// * `new_knots` - New knot points
132    /// * `new_delta_x` - Optional new segment widths (computed from knots if None)
133    /// * `new_symm` - Optional new symmetry parameter (keeps old if None)
134    ///
135    /// # Returns
136    ///
137    /// New polynomial with rescaled domain
138    pub fn rescale_domain(
139        &self,
140        new_knots: Vec<f64>,
141        new_delta_x: Option<Vec<f64>>,
142        new_symm: Option<i32>,
143    ) -> Self {
144        Self::new(
145            self.data.clone(),
146            new_knots,
147            self.l,
148            new_delta_x,
149            new_symm.unwrap_or(self.symm),
150        )
151    }
152
153    /// Scale all data values by a constant factor
154    ///
155    /// This is useful for normalizations, e.g., multiplying by √β for
156    /// Fourier transform preparations.
157    ///
158    /// # Arguments
159    ///
160    /// * `factor` - Scaling factor to multiply all data by
161    ///
162    /// # Returns
163    ///
164    /// New polynomial with scaled data
165    pub fn scale_data(&self, factor: f64) -> Self {
166        Self::with_data(
167            self,
168            mdarray::DTensor::<f64, 2>::from_fn(*self.data.shape(), |idx| self.data[idx] * factor),
169        )
170    }
171
172    /// Evaluate the polynomial at a given point
173    pub fn evaluate(&self, x: f64) -> f64 {
174        let (i, x_tilde) = self.split(x);
175        // Extract column i into a Vec
176        let coeffs: Vec<f64> = (0..self.data.shape().0)
177            .map(|row| self.data[[row, i]])
178            .collect();
179        let value = self.evaluate_legendre_polynomial(x_tilde, &coeffs);
180        value * self.norms[i]
181    }
182
183    /// Evaluate the polynomial at multiple points
184    pub fn evaluate_many(&self, xs: &[f64]) -> Vec<f64> {
185        xs.iter().map(|&x| self.evaluate(x)).collect()
186    }
187
188    /// Split x into segment index and normalized x
189    pub fn split(&self, x: f64) -> (usize, f64) {
190        if x < self.xmin || x > self.xmax {
191            panic!("x = {} is outside domain [{}, {}]", x, self.xmin, self.xmax);
192        }
193
194        // Find the segment containing x
195        for i in 0..self.knots.len() - 1 {
196            if x >= self.knots[i] && x <= self.knots[i + 1] {
197                // Transform x to [-1, 1] for Legendre polynomials
198                let x_tilde = 2.0 * (x - self.xm[i]) / self.delta_x[i];
199                return (i, x_tilde);
200            }
201        }
202
203        // Handle edge case: x exactly at the last knot
204        let last_idx = self.knots.len() - 2;
205        let x_tilde = 2.0 * (x - self.xm[last_idx]) / self.delta_x[last_idx];
206        (last_idx, x_tilde)
207    }
208
209    /// Evaluate Legendre polynomial using recurrence relation
210    pub fn evaluate_legendre_polynomial(&self, x: f64, coeffs: &[f64]) -> f64 {
211        if coeffs.is_empty() {
212            return 0.0;
213        }
214
215        let mut result = 0.0;
216        let mut p_prev = 1.0; // P_0(x) = 1
217        let mut p_curr = x; // P_1(x) = x
218
219        // Add first two terms
220        if !coeffs.is_empty() {
221            result += coeffs[0] * p_prev;
222        }
223        if coeffs.len() > 1 {
224            result += coeffs[1] * p_curr;
225        }
226
227        // Use recurrence relation: P_{n+1}(x) = ((2n+1)x*P_n(x) - n*P_{n-1}(x))/(n+1)
228        for n in 1..coeffs.len() - 1 {
229            let p_next =
230                ((2.0 * (n as f64) + 1.0) * x * p_curr - (n as f64) * p_prev) / ((n + 1) as f64);
231            result += coeffs[n + 1] * p_next;
232            p_prev = p_curr;
233            p_curr = p_next;
234        }
235
236        result
237    }
238
239    /// Compute derivative of the polynomial
240    pub fn deriv(&self, n: usize) -> Self {
241        if n == 0 {
242            return self.clone();
243        }
244
245        // Compute derivative coefficients
246        let mut ddata = self.data.clone();
247        for _ in 0..n {
248            ddata = self.compute_derivative_coefficients(&ddata);
249        }
250
251        // Apply scaling factors (C++: ddata.col(i) *= std::pow(inv_xs[i], n))
252        let ddata_shape = *ddata.shape();
253        for i in 0..ddata_shape.1 {
254            let inv_x_power = self.inv_xs[i].powi(n as i32);
255            for j in 0..ddata_shape.0 {
256                ddata[[j, i]] *= inv_x_power;
257            }
258        }
259
260        // Update symmetry: C++: int new_symm = std::pow(-1, n) * symm;
261        let new_symm = if n % 2 == 0 { self.symm } else { -self.symm };
262
263        Self {
264            data: ddata,
265            symm: new_symm,
266            ..self.clone()
267        }
268    }
269
270    /// Compute derivative coefficients using the same algorithm as C++ legder function
271    fn compute_derivative_coefficients(
272        &self,
273        coeffs: &mdarray::DTensor<f64, 2>,
274    ) -> mdarray::DTensor<f64, 2> {
275        let mut c = coeffs.clone();
276        let c_shape = *c.shape();
277        let mut n = c_shape.0;
278
279        // Single derivative step (equivalent to C++ legder with cnt=1)
280        if n <= 1 {
281            return mdarray::DTensor::<f64, 2>::from_elem([1, c.shape().1], 0.0);
282        }
283
284        n -= 1;
285        let mut der = mdarray::DTensor::<f64, 2>::from_elem([n, c.shape().1], 0.0);
286
287        // C++ implementation: for (int j = n; j >= 2; --j)
288        for j in (2..=n).rev() {
289            // C++: der.row(j - 1) = (2 * j - 1) * c.row(j);
290            for col in 0..c_shape.1 {
291                der[[j - 1, col]] = (2.0 * (j as f64) - 1.0) * c[[j, col]];
292            }
293            // C++: c.row(j - 2) += c.row(j);
294            for col in 0..c_shape.1 {
295                c[[j - 2, col]] += c[[j, col]];
296            }
297        }
298
299        // C++: if (n > 1) der.row(1) = 3 * c.row(2);
300        if n > 1 {
301            for col in 0..c_shape.1 {
302                der[[1, col]] = 3.0 * c[[2, col]];
303            }
304        }
305
306        // C++: der.row(0) = c.row(1);
307        for col in 0..c_shape.1 {
308            der[[0, col]] = c[[1, col]];
309        }
310
311        der
312    }
313
314    /// Compute derivatives at a point x
315    pub fn derivs(&self, x: f64) -> Vec<f64> {
316        let mut results = Vec::new();
317
318        // Compute up to polyorder derivatives
319        for n in 0..self.polyorder {
320            let deriv_poly = self.deriv(n);
321            results.push(deriv_poly.evaluate(x));
322        }
323
324        results
325    }
326
327    /// Compute overlap integral with a function
328    pub fn overlap<F>(&self, f: F) -> f64
329    where
330        F: Fn(f64) -> f64,
331    {
332        let mut integral = 0.0;
333
334        for i in 0..self.knots.len() - 1 {
335            let segment_integral =
336                self.gauss_legendre_quadrature(self.knots[i], self.knots[i + 1], |x| {
337                    self.evaluate(x) * f(x)
338                });
339            integral += segment_integral;
340        }
341
342        integral
343    }
344
345    /// Gauss-Legendre quadrature over [a, b]
346    fn gauss_legendre_quadrature<F>(&self, a: f64, b: f64, f: F) -> f64
347    where
348        F: Fn(f64) -> f64,
349    {
350        // 5-point Gauss-Legendre quadrature
351        const XG: [f64; 5] = [
352            -0.906179845938664,
353            -0.538469310105683,
354            0.0,
355            0.538469310105683,
356            0.906179845938664,
357        ];
358        const WG: [f64; 5] = [
359            0.236926885056189,
360            0.478628670499366,
361            0.568888888888889,
362            0.478628670499366,
363            0.236926885056189,
364        ];
365
366        let c1 = (b - a) / 2.0;
367        let c2 = (b + a) / 2.0;
368
369        let mut integral = 0.0;
370        for j in 0..5 {
371            let x = c1 * XG[j] + c2;
372            integral += WG[j] * f(x);
373        }
374
375        integral * c1
376    }
377
378    /// Find roots of the polynomial using C++ compatible algorithm
379    pub fn roots(&self) -> Vec<f64> {
380        let xmid = (self.xmax + self.xmin) / 2.0;
381
382        // Exploit symmetry: only search the right half, then mirror.
383        // This matches the Python 1.x / Julia v1 algorithm and guarantees
384        // exactly symmetric root positions.
385        let grid = if self.symm != 0 {
386            let nsegments = self.knots.len() - 1;
387            let mid_idx = nsegments / 2;
388            if (self.knots[mid_idx] - xmid).abs() < 1e-15 {
389                self.knots[mid_idx..].to_vec()
390            } else {
391                let mut g = vec![xmid];
392                g.extend(self.knots.iter().filter(|&&x| x > xmid));
393                g
394            }
395        } else {
396            self.knots.clone()
397        };
398
399        let refined_grid = self.refine_grid(&grid, 4);
400        let roots_half = self.find_all_roots(&refined_grid);
401
402        if self.symm == 1 {
403            // Even symmetry: roots on right half, mirror to left
404            let mut all_roots: Vec<f64> = roots_half
405                .iter()
406                .rev()
407                .map(|&r| (self.xmax + self.xmin) - r)
408                .collect();
409            all_roots.extend_from_slice(&roots_half);
410            all_roots
411        } else if self.symm == -1 {
412            // Odd symmetry: there must be a zero at xmid
413            let mut right = roots_half;
414            if !right.is_empty() {
415                // Remove the root at xmid if found (may be slightly off),
416                // or if f(xmid) and f'(xmid) have opposite signs (spurious zero)
417                let f_mid = self.evaluate(xmid);
418                let f_deriv_mid = self.deriv(1).evaluate(xmid);
419                if (right[0] - xmid).abs() < 1e-13 || f_mid * f_deriv_mid < 0.0 {
420                    right.remove(0);
421                }
422            }
423            let mut all_roots: Vec<f64> = right
424                .iter()
425                .rev()
426                .map(|&r| (self.xmax + self.xmin) - r)
427                .collect();
428            all_roots.push(xmid);
429            all_roots.extend_from_slice(&right);
430            all_roots
431        } else {
432            // No symmetry: search the full domain
433            let full_grid = self.refine_grid(&self.knots, 4);
434            self.find_all_roots(&full_grid)
435        }
436    }
437
438    /// Refine grid by factor alpha (C++ compatible)
439    fn refine_grid(&self, grid: &[f64], alpha: usize) -> Vec<f64> {
440        let mut refined = Vec::new();
441
442        for i in 0..grid.len() - 1 {
443            let start = grid[i];
444            let step = (grid[i + 1] - grid[i]) / (alpha as f64);
445            for j in 0..alpha {
446                refined.push(start + (j as f64) * step);
447            }
448        }
449        refined.push(grid[grid.len() - 1]);
450        refined
451    }
452
453    /// Find all roots using refined grid (C++ compatible)
454    fn find_all_roots(&self, xgrid: &[f64]) -> Vec<f64> {
455        if xgrid.is_empty() {
456            return Vec::new();
457        }
458
459        // Evaluate function at all grid points
460        let fx: Vec<f64> = xgrid.iter().map(|&x| self.evaluate(x)).collect();
461
462        // Find exact zeros (direct hits)
463        let mut x_hit = Vec::new();
464        for i in 0..fx.len() {
465            if fx[i] == 0.0 {
466                x_hit.push(xgrid[i]);
467            }
468        }
469
470        // Find sign changes
471        let mut sign_change = Vec::new();
472        for i in 0..fx.len() - 1 {
473            let has_sign_change = fx[i].signum() != fx[i + 1].signum();
474            let not_hit = fx[i] != 0.0 && fx[i + 1] != 0.0;
475            let sc = has_sign_change && not_hit;
476            sign_change.push(sc);
477        }
478
479        // If no sign changes, return only direct hits
480        if sign_change.iter().all(|&sc| !sc) {
481            x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
482            return x_hit;
483        }
484
485        // Find intervals with sign changes
486        let mut a_intervals = Vec::new();
487        let mut b_intervals = Vec::new();
488        let mut fa_values = Vec::new();
489
490        for i in 0..sign_change.len() {
491            if sign_change[i] {
492                a_intervals.push(xgrid[i]);
493                b_intervals.push(xgrid[i + 1]);
494                fa_values.push(fx[i]);
495            }
496        }
497
498        // Calculate epsilon for convergence
499        let max_elm = xgrid.iter().map(|&x| x.abs()).fold(0.0, f64::max);
500        let epsilon_x = f64::EPSILON * max_elm;
501
502        // Use bisection for each interval with sign change
503        for i in 0..a_intervals.len() {
504            let root = self.bisect(a_intervals[i], b_intervals[i], fa_values[i], epsilon_x);
505            x_hit.push(root);
506        }
507
508        // Sort and return
509        x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
510        x_hit
511    }
512
513    /// Bisection method to find root (C++ compatible)
514    fn bisect(&self, a: f64, b: f64, fa: f64, eps: f64) -> f64 {
515        let mut a = a;
516        let mut b = b;
517        let mut fa = fa;
518
519        loop {
520            let mid = (a + b) / 2.0;
521            if self.close_enough(a, mid, eps) {
522                return mid;
523            }
524
525            let fmid = self.evaluate(mid);
526            if fa.signum() != fmid.signum() {
527                b = mid;
528            } else {
529                a = mid;
530                fa = fmid;
531            }
532        }
533    }
534
535    /// Check if two values are close enough (C++ compatible)
536    fn close_enough(&self, a: f64, b: f64, eps: f64) -> bool {
537        (a - b).abs() <= eps
538    }
539
540    // Accessor methods to match C++ interface
541    pub fn get_xmin(&self) -> f64 {
542        self.xmin
543    }
544    pub fn get_xmax(&self) -> f64 {
545        self.xmax
546    }
547    pub fn get_l(&self) -> i32 {
548        self.l
549    }
550    pub fn get_domain(&self) -> (f64, f64) {
551        (self.xmin, self.xmax)
552    }
553    pub fn get_knots(&self) -> &[f64] {
554        &self.knots
555    }
556    pub fn get_delta_x(&self) -> &[f64] {
557        &self.delta_x
558    }
559    pub fn get_symm(&self) -> i32 {
560        self.symm
561    }
562    pub fn get_data(&self) -> &mdarray::DTensor<f64, 2> {
563        &self.data
564    }
565    pub fn get_norms(&self) -> &[f64] {
566        &self.norms
567    }
568    pub fn get_polyorder(&self) -> usize {
569        self.polyorder
570    }
571}
572
573/// Vector of piecewise Legendre polynomials
574#[derive(Debug, Clone)]
575pub struct PiecewiseLegendrePolyVector {
576    /// Individual polynomials
577    pub polyvec: Vec<PiecewiseLegendrePoly>,
578}
579
580impl PiecewiseLegendrePolyVector {
581    /// Constructor with a vector of PiecewiseLegendrePoly
582    ///
583    /// # Panics
584    /// Panics if the input vector is empty, as empty PiecewiseLegendrePolyVector is not meaningful
585    pub fn new(polyvec: Vec<PiecewiseLegendrePoly>) -> Self {
586        if polyvec.is_empty() {
587            panic!("Cannot create empty PiecewiseLegendrePolyVector");
588        }
589        Self { polyvec }
590    }
591
592    /// Get the polynomials
593    pub fn get_polys(&self) -> &[PiecewiseLegendrePoly] {
594        &self.polyvec
595    }
596
597    /// Constructor with a 3D array, knots, and symmetry vector
598    pub fn from_3d_data(
599        data3d: mdarray::DTensor<f64, 3>,
600        knots: Vec<f64>,
601        symm: Option<Vec<i32>>,
602    ) -> Self {
603        let npolys = data3d.shape().2;
604        let mut polyvec = Vec::with_capacity(npolys);
605
606        if let Some(ref symm_vec) = symm {
607            if symm_vec.len() != npolys {
608                panic!("Sizes of data and symm don't match");
609            }
610        }
611
612        // Compute delta_x from knots
613        let delta_x: Vec<f64> = (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect();
614
615        for i in 0..npolys {
616            // Extract 2D data for this polynomial
617            let data3d_shape = data3d.shape();
618            let mut data =
619                mdarray::DTensor::<f64, 2>::from_elem([data3d_shape.0, data3d_shape.1], 0.0);
620            for j in 0..data3d_shape.0 {
621                for k in 0..data3d_shape.1 {
622                    data[[j, k]] = data3d[[j, k, i]];
623                }
624            }
625
626            let poly = PiecewiseLegendrePoly::new(
627                data,
628                knots.clone(),
629                i as i32,
630                Some(delta_x.clone()),
631                symm.as_ref().map_or(0, |s| s[i]),
632            );
633
634            polyvec.push(poly);
635        }
636
637        Self { polyvec }
638    }
639
640    /// Get the size of the vector
641    pub fn size(&self) -> usize {
642        self.polyvec.len()
643    }
644
645    /// Rescale domain for all polynomials in the vector
646    ///
647    /// Creates a new PiecewiseLegendrePolyVector where each polynomial has
648    /// the same data but new knots and delta_x.
649    ///
650    /// # Arguments
651    ///
652    /// * `new_knots` - New knot points (same for all polynomials)
653    /// * `new_delta_x` - Optional new segment widths
654    /// * `new_symm` - Optional vector of new symmetry parameters (one per polynomial)
655    ///
656    /// # Returns
657    ///
658    /// New vector with rescaled domains
659    pub fn rescale_domain(
660        &self,
661        new_knots: Vec<f64>,
662        new_delta_x: Option<Vec<f64>>,
663        new_symm: Option<Vec<i32>>,
664    ) -> Self {
665        let polyvec = self
666            .polyvec
667            .iter()
668            .enumerate()
669            .map(|(i, poly)| {
670                let symm = new_symm.as_ref().map(|s| s[i]);
671                poly.rescale_domain(new_knots.clone(), new_delta_x.clone(), symm)
672            })
673            .collect();
674
675        Self { polyvec }
676    }
677
678    /// Scale all data values by a constant factor
679    ///
680    /// Multiplies the data of all polynomials by the same factor.
681    ///
682    /// # Arguments
683    ///
684    /// * `factor` - Scaling factor to multiply all data by
685    ///
686    /// # Returns
687    ///
688    /// New vector with scaled data
689    pub fn scale_data(&self, factor: f64) -> Self {
690        let polyvec = self
691            .polyvec
692            .iter()
693            .map(|poly| poly.scale_data(factor))
694            .collect();
695
696        Self { polyvec }
697    }
698
699    /// Get polynomial by index (immutable)
700    pub fn get(&self, index: usize) -> Option<&PiecewiseLegendrePoly> {
701        self.polyvec.get(index)
702    }
703
704    /// Get polynomial by index (mutable) - deprecated, use immutable design instead
705    #[deprecated(
706        note = "PiecewiseLegendrePolyVector is designed to be immutable. Use get() and create new instances for modifications."
707    )]
708    pub fn get_mut(&mut self, index: usize) -> Option<&mut PiecewiseLegendrePoly> {
709        self.polyvec.get_mut(index)
710    }
711
712    /// Extract a single polynomial as a vector
713    pub fn slice_single(&self, index: usize) -> Option<Self> {
714        self.polyvec.get(index).map(|poly| Self {
715            polyvec: vec![poly.clone()],
716        })
717    }
718
719    /// Extract multiple polynomials by indices
720    pub fn slice_multi(&self, indices: &[usize]) -> Self {
721        // Validate indices
722        for &idx in indices {
723            if idx >= self.polyvec.len() {
724                panic!("Index {} out of range", idx);
725            }
726        }
727
728        // Check for duplicates
729        {
730            let mut unique_indices = indices.to_vec();
731            unique_indices.sort();
732            unique_indices.dedup();
733            if unique_indices.len() != indices.len() {
734                panic!("Duplicate indices not allowed");
735            }
736        }
737
738        let new_polyvec: Vec<_> = indices
739            .iter()
740            .map(|&idx| self.polyvec[idx].clone())
741            .collect();
742
743        Self {
744            polyvec: new_polyvec,
745        }
746    }
747
748    /// Evaluate all polynomials at a single point
749    pub fn evaluate_at(&self, x: f64) -> Vec<f64> {
750        self.polyvec.iter().map(|poly| poly.evaluate(x)).collect()
751    }
752
753    /// Evaluate all polynomials at multiple points
754    pub fn evaluate_at_many(&self, xs: &[f64]) -> mdarray::DTensor<f64, 2> {
755        let n_funcs = self.polyvec.len();
756        let n_points = xs.len();
757        let mut results = mdarray::DTensor::<f64, 2>::from_elem([n_funcs, n_points], 0.0);
758
759        for (i, poly) in self.polyvec.iter().enumerate() {
760            for (j, &x) in xs.iter().enumerate() {
761                results[[i, j]] = poly.evaluate(x);
762            }
763        }
764
765        results
766    }
767
768    // Accessor methods to match C++ interface
769    pub fn xmin(&self) -> f64 {
770        if self.polyvec.is_empty() {
771            panic!("Cannot get xmin from empty PiecewiseLegendrePolyVector");
772        }
773        self.polyvec[0].xmin
774    }
775
776    pub fn xmax(&self) -> f64 {
777        if self.polyvec.is_empty() {
778            panic!("Cannot get xmax from empty PiecewiseLegendrePolyVector");
779        }
780        self.polyvec[0].xmax
781    }
782
783    pub fn get_knots(&self, tolerance: Option<f64>) -> Vec<f64> {
784        if self.polyvec.is_empty() {
785            panic!("Cannot get knots from empty PiecewiseLegendrePolyVector");
786        }
787        const DEFAULT_TOLERANCE: f64 = 1e-10;
788        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
789
790        // Collect all knots from all polynomials
791        let mut all_knots = Vec::new();
792        for poly in &self.polyvec {
793            for &knot in &poly.knots {
794                all_knots.push(knot);
795            }
796        }
797
798        // Sort and remove duplicates
799        {
800            all_knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
801            all_knots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
802        }
803        all_knots
804    }
805
806    pub fn get_delta_x(&self) -> Vec<f64> {
807        if self.polyvec.is_empty() {
808            panic!("Cannot get delta_x from empty PiecewiseLegendrePolyVector");
809        }
810        self.polyvec[0].delta_x.clone()
811    }
812
813    pub fn get_polyorder(&self) -> usize {
814        if self.polyvec.is_empty() {
815            panic!("Cannot get polyorder from empty PiecewiseLegendrePolyVector");
816        }
817        self.polyvec[0].polyorder
818    }
819
820    pub fn get_norms(&self) -> &[f64] {
821        if self.polyvec.is_empty() {
822            panic!("Cannot get norms from empty PiecewiseLegendrePolyVector");
823        }
824        &self.polyvec[0].norms
825    }
826
827    pub fn get_symm(&self) -> Vec<i32> {
828        if self.polyvec.is_empty() {
829            panic!("Cannot get symm from empty PiecewiseLegendrePolyVector");
830        }
831        self.polyvec.iter().map(|poly| poly.symm).collect()
832    }
833
834    /// Get data as 3D tensor: [segment][degree][polynomial]
835    pub fn get_data(&self) -> mdarray::DTensor<f64, 3> {
836        if self.polyvec.is_empty() {
837            panic!("Cannot get data from empty PiecewiseLegendrePolyVector");
838        }
839
840        let nsegments = self.polyvec[0].data.shape().1;
841        let polyorder = self.polyvec[0].polyorder;
842        let npolys = self.polyvec.len();
843
844        let mut data = mdarray::DTensor::<f64, 3>::from_elem([nsegments, polyorder, npolys], 0.0);
845
846        for (poly_idx, poly) in self.polyvec.iter().enumerate() {
847            for segment in 0..nsegments {
848                for degree in 0..polyorder {
849                    data[[segment, degree, poly_idx]] = poly.data[[degree, segment]];
850                }
851            }
852        }
853
854        data
855    }
856
857    /// Find roots of all polynomials
858    pub fn roots(&self, tolerance: Option<f64>) -> Vec<f64> {
859        if self.polyvec.is_empty() {
860            panic!("Cannot get roots from empty PiecewiseLegendrePolyVector");
861        }
862        const DEFAULT_TOLERANCE: f64 = 1e-10;
863        let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
864        let mut all_roots = Vec::new();
865
866        for poly in &self.polyvec {
867            let poly_roots = poly.roots();
868            for root in poly_roots {
869                all_roots.push(root);
870            }
871        }
872
873        // Sort in descending order and remove duplicates (like C++ implementation)
874        {
875            all_roots.sort_by(|a, b| b.partial_cmp(a).unwrap());
876            all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
877        }
878        all_roots
879    }
880
881    /// Get reference to last polynomial
882    ///
883    /// C++ equivalent: u.polyvec.back()
884    pub fn last(&self) -> &PiecewiseLegendrePoly {
885        self.polyvec
886            .last()
887            .expect("Cannot get last from empty PiecewiseLegendrePolyVector")
888    }
889
890    /// Get the number of roots
891    pub fn nroots(&self, tolerance: Option<f64>) -> usize {
892        if self.polyvec.is_empty() {
893            panic!("Cannot get nroots from empty PiecewiseLegendrePolyVector");
894        }
895        self.roots(tolerance).len()
896    }
897}
898
899impl std::ops::Index<usize> for PiecewiseLegendrePolyVector {
900    type Output = PiecewiseLegendrePoly;
901
902    fn index(&self, index: usize) -> &Self::Output {
903        &self.polyvec[index]
904    }
905}
906
907/// Get default sampling points in [-1, 1]
908///
909/// C++ implementation: libsparseir/include/sparseir/basis.hpp:287-310
910///
911/// For orthogonal polynomials (the high-T limit of IR), we know that the
912/// ideal sampling points for a basis of size L are the roots of the L-th
913/// polynomial. We empirically find that these stay good sampling points
914/// for our kernels (probably because the kernels are totally positive).
915///
916/// If we do not have enough polynomials in the basis, we approximate the
917/// roots of the L'th polynomial by the extrema of the last basis function,
918/// which is sensible due to the strong interleaving property of these
919/// functions' roots.
920pub fn default_sampling_points(u: &PiecewiseLegendrePolyVector, l: usize) -> Vec<f64> {
921    // C++: if (u.xmin() != -1.0 || u.xmax() != 1.0)
922    //          throw std::runtime_error("Expecting unscaled functions here.");
923    if (u.xmin() - (-1.0)).abs() > 1e-10 || (u.xmax() - 1.0).abs() > 1e-10 {
924        panic!("Expecting unscaled functions here.");
925    }
926
927    let x0 = if l < u.polyvec.len() {
928        // C++: return u.polyvec[L].roots();
929        u[l].roots()
930    } else {
931        // C++: PiecewiseLegendrePoly poly = u.polyvec.back();
932        //      Eigen::VectorXd maxima = poly.deriv().roots();
933        let poly = u.last();
934        let poly_deriv = poly.deriv(1);
935        let maxima = poly_deriv.roots();
936
937        // C++: double left = (maxima[0] + poly.xmin) / 2.0;
938        let left = (maxima[0] + poly.xmin) / 2.0;
939
940        // C++: double right = (maxima[maxima.size() - 1] + poly.xmax) / 2.0;
941        let right = (maxima[maxima.len() - 1] + poly.xmax) / 2.0;
942
943        // C++: Eigen::VectorXd x0(maxima.size() + 2);
944        //      x0[0] = left;
945        //      x0.segment(1, maxima.size()) = maxima;
946        //      x0[x0.size() - 1] = right;
947        let mut x0_vec = Vec::with_capacity(maxima.len() + 2);
948        x0_vec.push(left);
949        x0_vec.extend_from_slice(&maxima);
950        x0_vec.push(right);
951        x0_vec
952    };
953
954    // C++: if (x0.size() != L) { warning }
955    if x0.len() != l {
956        eprintln!(
957            "Warning: Expecting to get {} sampling points for corresponding basis function, \
958             instead got {}. This may happen if not enough precision is left in the polynomial.",
959            l,
960            x0.len()
961        );
962    }
963
964    x0
965}
966
// `IndexMut` is deliberately not implemented: `PiecewiseLegendrePolyVector` is
// meant to be immutable, so build a new instance when a modification is needed.

// There is no `Fn`/`FnOnce` implementation either (that would need unstable
// features); call `evaluate_at()` or `evaluate_at_many()` directly.

// Unit tests live in `poly_tests.rs` next to this file and are compiled only for
// test builds.
#[cfg(test)]
#[path = "poly_tests.rs"]
mod poly_tests;