/// A piecewise Legendre polynomial on the interval `[xmin, xmax]`.
///
/// The domain is split into segments by `knots`. On segment `i` the function is a Legendre
/// series in the rescaled variable `x_tilde = 2 (x - xm[i]) / delta_x[i]`, whose coefficients
/// are column `i` of `data`; the series value is multiplied by `norms[i]`.
#[derive(Debug, Clone)]
pub struct PiecewiseLegendrePoly {
    /// Number of Legendre coefficients per segment; `new` sets it to the row count of `data`.
    pub polyorder: usize,
    /// Left end of the domain (the first knot).
    pub xmin: f64,
    /// Right end of the domain (the last knot).
    pub xmax: f64,
    /// Segment boundaries, strictly increasing, with one more entry than there are segments.
    pub knots: Vec<f64>,
    /// Width of each segment, `knots[i + 1] - knots[i]`.
    pub delta_x: Vec<f64>,
    /// Coefficient matrix of shape `(polyorder, nsegments)`; column `i` belongs to segment `i`.
    pub data: mdarray::DTensor<f64, 2>,
    /// Symmetry flag carried through transformations; odd-order derivatives negate it.
    pub symm: i32,
    /// Integer label of this polynomial (e.g. its index inside a polynomial vector).
    pub l: i32,
    /// Midpoint of each segment.
    pub xm: Vec<f64>,
    /// `2 / delta_x[i]` per segment, i.e. the factor `d x_tilde / d x`.
    pub inv_xs: Vec<f64>,
    /// `sqrt(inv_xs[i])` per segment; normalization applied to segment values.
    pub norms: Vec<f64>,
}
32
33impl PiecewiseLegendrePoly {
34 pub fn new(
36 data: mdarray::DTensor<f64, 2>,
37 knots: Vec<f64>,
38 l: i32,
39 delta_x: Option<Vec<f64>>,
40 symm: i32,
41 ) -> Self {
42 let polyorder = data.shape().0;
43 let nsegments = data.shape().1;
44
45 if knots.len() != nsegments + 1 {
46 panic!(
47 "Invalid knots array: expected {} knots, got {}",
48 nsegments + 1,
49 knots.len()
50 );
51 }
52
53 for i in 1..knots.len() {
55 if knots[i] <= knots[i - 1] {
56 panic!("Knots must be monotonically increasing");
57 }
58 }
59
60 let delta_x =
62 delta_x.unwrap_or_else(|| (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect());
63
64 for i in 0..delta_x.len() {
66 let expected = knots[i + 1] - knots[i];
67 if (delta_x[i] - expected).abs() > 1e-10 {
68 panic!("delta_x must match knots");
69 }
70 }
71
72 let xm: Vec<f64> = (0..nsegments)
74 .map(|i| 0.5 * (knots[i] + knots[i + 1]))
75 .collect();
76
77 let inv_xs: Vec<f64> = delta_x.iter().map(|&dx| 2.0 / dx).collect();
79
80 let norms: Vec<f64> = inv_xs.iter().map(|&inv_x| inv_x.sqrt()).collect();
82
83 Self {
84 polyorder,
85 xmin: knots[0],
86 xmax: knots[knots.len() - 1],
87 knots,
88 delta_x,
89 data,
90 symm,
91 l,
92 xm,
93 inv_xs,
94 norms,
95 }
96 }
97
98 pub fn with_data(&self, new_data: mdarray::DTensor<f64, 2>) -> Self {
100 Self {
101 data: new_data,
102 ..self.clone()
103 }
104 }
105
106 pub fn symm(&self) -> i32 {
108 self.symm
109 }
110
111 pub fn with_data_and_symmetry(
113 &self,
114 new_data: mdarray::DTensor<f64, 2>,
115 new_symm: i32,
116 ) -> Self {
117 Self {
118 data: new_data,
119 symm: new_symm,
120 ..self.clone()
121 }
122 }
123
124 pub fn rescale_domain(
139 &self,
140 new_knots: Vec<f64>,
141 new_delta_x: Option<Vec<f64>>,
142 new_symm: Option<i32>,
143 ) -> Self {
144 Self::new(
145 self.data.clone(),
146 new_knots,
147 self.l,
148 new_delta_x,
149 new_symm.unwrap_or(self.symm),
150 )
151 }
152
153 pub fn scale_data(&self, factor: f64) -> Self {
166 Self::with_data(
167 self,
168 mdarray::DTensor::<f64, 2>::from_fn(*self.data.shape(), |idx| self.data[idx] * factor),
169 )
170 }
171
172 pub fn evaluate(&self, x: f64) -> f64 {
174 let (i, x_tilde) = self.split(x);
175 let coeffs: Vec<f64> = (0..self.data.shape().0)
177 .map(|row| self.data[[row, i]])
178 .collect();
179 let value = self.evaluate_legendre_polynomial(x_tilde, &coeffs);
180 value * self.norms[i]
181 }
182
183 pub fn evaluate_many(&self, xs: &[f64]) -> Vec<f64> {
185 xs.iter().map(|&x| self.evaluate(x)).collect()
186 }
187
188 pub fn split(&self, x: f64) -> (usize, f64) {
190 if x < self.xmin || x > self.xmax {
191 panic!("x = {} is outside domain [{}, {}]", x, self.xmin, self.xmax);
192 }
193
194 for i in 0..self.knots.len() - 1 {
196 if x >= self.knots[i] && x <= self.knots[i + 1] {
197 let x_tilde = 2.0 * (x - self.xm[i]) / self.delta_x[i];
199 return (i, x_tilde);
200 }
201 }
202
203 let last_idx = self.knots.len() - 2;
205 let x_tilde = 2.0 * (x - self.xm[last_idx]) / self.delta_x[last_idx];
206 (last_idx, x_tilde)
207 }
208
209 pub fn evaluate_legendre_polynomial(&self, x: f64, coeffs: &[f64]) -> f64 {
211 if coeffs.is_empty() {
212 return 0.0;
213 }
214
215 let mut result = 0.0;
216 let mut p_prev = 1.0; let mut p_curr = x; if !coeffs.is_empty() {
221 result += coeffs[0] * p_prev;
222 }
223 if coeffs.len() > 1 {
224 result += coeffs[1] * p_curr;
225 }
226
227 for n in 1..coeffs.len() - 1 {
229 let p_next =
230 ((2.0 * (n as f64) + 1.0) * x * p_curr - (n as f64) * p_prev) / ((n + 1) as f64);
231 result += coeffs[n + 1] * p_next;
232 p_prev = p_curr;
233 p_curr = p_next;
234 }
235
236 result
237 }
238
239 pub fn deriv(&self, n: usize) -> Self {
241 if n == 0 {
242 return self.clone();
243 }
244
245 let mut ddata = self.data.clone();
247 for _ in 0..n {
248 ddata = self.compute_derivative_coefficients(&ddata);
249 }
250
251 let ddata_shape = *ddata.shape();
253 for i in 0..ddata_shape.1 {
254 let inv_x_power = self.inv_xs[i].powi(n as i32);
255 for j in 0..ddata_shape.0 {
256 ddata[[j, i]] *= inv_x_power;
257 }
258 }
259
260 let new_symm = if n % 2 == 0 { self.symm } else { -self.symm };
262
263 Self {
264 data: ddata,
265 symm: new_symm,
266 ..self.clone()
267 }
268 }
269
270 fn compute_derivative_coefficients(
272 &self,
273 coeffs: &mdarray::DTensor<f64, 2>,
274 ) -> mdarray::DTensor<f64, 2> {
275 let mut c = coeffs.clone();
276 let c_shape = *c.shape();
277 let mut n = c_shape.0;
278
279 if n <= 1 {
281 return mdarray::DTensor::<f64, 2>::from_elem([1, c.shape().1], 0.0);
282 }
283
284 n -= 1;
285 let mut der = mdarray::DTensor::<f64, 2>::from_elem([n, c.shape().1], 0.0);
286
287 for j in (2..=n).rev() {
289 for col in 0..c_shape.1 {
291 der[[j - 1, col]] = (2.0 * (j as f64) - 1.0) * c[[j, col]];
292 }
293 for col in 0..c_shape.1 {
295 c[[j - 2, col]] += c[[j, col]];
296 }
297 }
298
299 if n > 1 {
301 for col in 0..c_shape.1 {
302 der[[1, col]] = 3.0 * c[[2, col]];
303 }
304 }
305
306 for col in 0..c_shape.1 {
308 der[[0, col]] = c[[1, col]];
309 }
310
311 der
312 }
313
314 pub fn derivs(&self, x: f64) -> Vec<f64> {
316 let mut results = Vec::new();
317
318 for n in 0..self.polyorder {
320 let deriv_poly = self.deriv(n);
321 results.push(deriv_poly.evaluate(x));
322 }
323
324 results
325 }
326
327 pub fn overlap<F>(&self, f: F) -> f64
329 where
330 F: Fn(f64) -> f64,
331 {
332 let mut integral = 0.0;
333
334 for i in 0..self.knots.len() - 1 {
335 let segment_integral =
336 self.gauss_legendre_quadrature(self.knots[i], self.knots[i + 1], |x| {
337 self.evaluate(x) * f(x)
338 });
339 integral += segment_integral;
340 }
341
342 integral
343 }
344
345 fn gauss_legendre_quadrature<F>(&self, a: f64, b: f64, f: F) -> f64
347 where
348 F: Fn(f64) -> f64,
349 {
350 const XG: [f64; 5] = [
352 -0.906179845938664,
353 -0.538469310105683,
354 0.0,
355 0.538469310105683,
356 0.906179845938664,
357 ];
358 const WG: [f64; 5] = [
359 0.236926885056189,
360 0.478628670499366,
361 0.568888888888889,
362 0.478628670499366,
363 0.236926885056189,
364 ];
365
366 let c1 = (b - a) / 2.0;
367 let c2 = (b + a) / 2.0;
368
369 let mut integral = 0.0;
370 for j in 0..5 {
371 let x = c1 * XG[j] + c2;
372 integral += WG[j] * f(x);
373 }
374
375 integral * c1
376 }
377
378 pub fn roots(&self) -> Vec<f64> {
380 let refined_grid = self.refine_grid(&self.knots, 4);
383
384 self.find_all_roots(&refined_grid)
386 }
387
388 fn refine_grid(&self, grid: &[f64], alpha: usize) -> Vec<f64> {
390 let mut refined = Vec::new();
391
392 for i in 0..grid.len() - 1 {
393 let start = grid[i];
394 let step = (grid[i + 1] - grid[i]) / (alpha as f64);
395 for j in 0..alpha {
396 refined.push(start + (j as f64) * step);
397 }
398 }
399 refined.push(grid[grid.len() - 1]);
400 refined
401 }
402
403 fn find_all_roots(&self, xgrid: &[f64]) -> Vec<f64> {
405 if xgrid.is_empty() {
406 return Vec::new();
407 }
408
409 let fx: Vec<f64> = xgrid.iter().map(|&x| self.evaluate(x)).collect();
411
412 let mut x_hit = Vec::new();
414 for i in 0..fx.len() {
415 if fx[i] == 0.0 {
416 x_hit.push(xgrid[i]);
417 }
418 }
419
420 let mut sign_change = Vec::new();
422 for i in 0..fx.len() - 1 {
423 let has_sign_change = fx[i].signum() != fx[i + 1].signum();
424 let not_hit = fx[i] != 0.0 && fx[i + 1] != 0.0;
425 sign_change.push(has_sign_change && not_hit);
426 }
427
428 if sign_change.iter().all(|&sc| !sc) {
430 x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
431 return x_hit;
432 }
433
434 let mut a_intervals = Vec::new();
436 let mut b_intervals = Vec::new();
437 let mut fa_values = Vec::new();
438
439 for i in 0..sign_change.len() {
440 if sign_change[i] {
441 a_intervals.push(xgrid[i]);
442 b_intervals.push(xgrid[i + 1]);
443 fa_values.push(fx[i]);
444 }
445 }
446
447 let max_elm = xgrid.iter().map(|&x| x.abs()).fold(0.0, f64::max);
449 let epsilon_x = f64::EPSILON * max_elm;
450
451 for i in 0..a_intervals.len() {
453 let root = self.bisect(a_intervals[i], b_intervals[i], fa_values[i], epsilon_x);
454 x_hit.push(root);
455 }
456
457 x_hit.sort_by(|a, b| a.partial_cmp(b).unwrap());
459 x_hit
460 }
461
462 fn bisect(&self, a: f64, b: f64, fa: f64, eps: f64) -> f64 {
464 let mut a = a;
465 let mut b = b;
466 let mut fa = fa;
467
468 loop {
469 let mid = (a + b) / 2.0;
470 if self.close_enough(a, mid, eps) {
471 return mid;
472 }
473
474 let fmid = self.evaluate(mid);
475 if fa.signum() != fmid.signum() {
476 b = mid;
477 } else {
478 a = mid;
479 fa = fmid;
480 }
481 }
482 }
483
484 fn close_enough(&self, a: f64, b: f64, eps: f64) -> bool {
486 (a - b).abs() <= eps
487 }
488
489 pub fn get_xmin(&self) -> f64 {
491 self.xmin
492 }
493 pub fn get_xmax(&self) -> f64 {
494 self.xmax
495 }
496 pub fn get_l(&self) -> i32 {
497 self.l
498 }
499 pub fn get_domain(&self) -> (f64, f64) {
500 (self.xmin, self.xmax)
501 }
502 pub fn get_knots(&self) -> &[f64] {
503 &self.knots
504 }
505 pub fn get_delta_x(&self) -> &[f64] {
506 &self.delta_x
507 }
508 pub fn get_symm(&self) -> i32 {
509 self.symm
510 }
511 pub fn get_data(&self) -> &mdarray::DTensor<f64, 2> {
512 &self.data
513 }
514 pub fn get_norms(&self) -> &[f64] {
515 &self.norms
516 }
517 pub fn get_polyorder(&self) -> usize {
518 self.polyorder
519 }
520}
521
/// An ordered collection of [`PiecewiseLegendrePoly`] functions, typically sharing one domain.
///
/// `new` rejects an empty list; the accessors that read the first element panic on an
/// empty vector.
#[derive(Debug, Clone)]
pub struct PiecewiseLegendrePolyVector {
    /// The polynomials, in order.
    pub polyvec: Vec<PiecewiseLegendrePoly>,
}
528
529impl PiecewiseLegendrePolyVector {
530 pub fn new(polyvec: Vec<PiecewiseLegendrePoly>) -> Self {
535 if polyvec.is_empty() {
536 panic!("Cannot create empty PiecewiseLegendrePolyVector");
537 }
538 Self { polyvec }
539 }
540
541 pub fn get_polys(&self) -> &[PiecewiseLegendrePoly] {
543 &self.polyvec
544 }
545
546 pub fn from_3d_data(
548 data3d: mdarray::DTensor<f64, 3>,
549 knots: Vec<f64>,
550 symm: Option<Vec<i32>>,
551 ) -> Self {
552 let npolys = data3d.shape().2;
553 let mut polyvec = Vec::with_capacity(npolys);
554
555 if let Some(ref symm_vec) = symm {
556 if symm_vec.len() != npolys {
557 panic!("Sizes of data and symm don't match");
558 }
559 }
560
561 let delta_x: Vec<f64> = (1..knots.len()).map(|i| knots[i] - knots[i - 1]).collect();
563
564 for i in 0..npolys {
565 let data3d_shape = data3d.shape();
567 let mut data =
568 mdarray::DTensor::<f64, 2>::from_elem([data3d_shape.0, data3d_shape.1], 0.0);
569 for j in 0..data3d_shape.0 {
570 for k in 0..data3d_shape.1 {
571 data[[j, k]] = data3d[[j, k, i]];
572 }
573 }
574
575 let poly = PiecewiseLegendrePoly::new(
576 data,
577 knots.clone(),
578 i as i32,
579 Some(delta_x.clone()),
580 symm.as_ref().map_or(0, |s| s[i]),
581 );
582
583 polyvec.push(poly);
584 }
585
586 Self { polyvec }
587 }
588
589 pub fn size(&self) -> usize {
591 self.polyvec.len()
592 }
593
594 pub fn rescale_domain(
609 &self,
610 new_knots: Vec<f64>,
611 new_delta_x: Option<Vec<f64>>,
612 new_symm: Option<Vec<i32>>,
613 ) -> Self {
614 let polyvec = self
615 .polyvec
616 .iter()
617 .enumerate()
618 .map(|(i, poly)| {
619 let symm = new_symm.as_ref().map(|s| s[i]);
620 poly.rescale_domain(new_knots.clone(), new_delta_x.clone(), symm)
621 })
622 .collect();
623
624 Self { polyvec }
625 }
626
627 pub fn scale_data(&self, factor: f64) -> Self {
639 let polyvec = self
640 .polyvec
641 .iter()
642 .map(|poly| poly.scale_data(factor))
643 .collect();
644
645 Self { polyvec }
646 }
647
648 pub fn get(&self, index: usize) -> Option<&PiecewiseLegendrePoly> {
650 self.polyvec.get(index)
651 }
652
653 #[deprecated(
655 note = "PiecewiseLegendrePolyVector is designed to be immutable. Use get() and create new instances for modifications."
656 )]
657 pub fn get_mut(&mut self, index: usize) -> Option<&mut PiecewiseLegendrePoly> {
658 self.polyvec.get_mut(index)
659 }
660
661 pub fn slice_single(&self, index: usize) -> Option<Self> {
663 self.polyvec.get(index).map(|poly| Self {
664 polyvec: vec![poly.clone()],
665 })
666 }
667
668 pub fn slice_multi(&self, indices: &[usize]) -> Self {
670 for &idx in indices {
672 if idx >= self.polyvec.len() {
673 panic!("Index {} out of range", idx);
674 }
675 }
676
677 {
679 let mut unique_indices = indices.to_vec();
680 unique_indices.sort();
681 unique_indices.dedup();
682 if unique_indices.len() != indices.len() {
683 panic!("Duplicate indices not allowed");
684 }
685 }
686
687 let new_polyvec: Vec<_> = indices
688 .iter()
689 .map(|&idx| self.polyvec[idx].clone())
690 .collect();
691
692 Self {
693 polyvec: new_polyvec,
694 }
695 }
696
697 pub fn evaluate_at(&self, x: f64) -> Vec<f64> {
699 self.polyvec.iter().map(|poly| poly.evaluate(x)).collect()
700 }
701
702 pub fn evaluate_at_many(&self, xs: &[f64]) -> mdarray::DTensor<f64, 2> {
704 let n_funcs = self.polyvec.len();
705 let n_points = xs.len();
706 let mut results = mdarray::DTensor::<f64, 2>::from_elem([n_funcs, n_points], 0.0);
707
708 for (i, poly) in self.polyvec.iter().enumerate() {
709 for (j, &x) in xs.iter().enumerate() {
710 results[[i, j]] = poly.evaluate(x);
711 }
712 }
713
714 results
715 }
716
717 pub fn xmin(&self) -> f64 {
719 if self.polyvec.is_empty() {
720 panic!("Cannot get xmin from empty PiecewiseLegendrePolyVector");
721 }
722 self.polyvec[0].xmin
723 }
724
725 pub fn xmax(&self) -> f64 {
726 if self.polyvec.is_empty() {
727 panic!("Cannot get xmax from empty PiecewiseLegendrePolyVector");
728 }
729 self.polyvec[0].xmax
730 }
731
732 pub fn get_knots(&self, tolerance: Option<f64>) -> Vec<f64> {
733 if self.polyvec.is_empty() {
734 panic!("Cannot get knots from empty PiecewiseLegendrePolyVector");
735 }
736 const DEFAULT_TOLERANCE: f64 = 1e-10;
737 let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
738
739 let mut all_knots = Vec::new();
741 for poly in &self.polyvec {
742 for &knot in &poly.knots {
743 all_knots.push(knot);
744 }
745 }
746
747 {
749 all_knots.sort_by(|a, b| a.partial_cmp(b).unwrap());
750 all_knots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
751 }
752 all_knots
753 }
754
755 pub fn get_delta_x(&self) -> Vec<f64> {
756 if self.polyvec.is_empty() {
757 panic!("Cannot get delta_x from empty PiecewiseLegendrePolyVector");
758 }
759 self.polyvec[0].delta_x.clone()
760 }
761
762 pub fn get_polyorder(&self) -> usize {
763 if self.polyvec.is_empty() {
764 panic!("Cannot get polyorder from empty PiecewiseLegendrePolyVector");
765 }
766 self.polyvec[0].polyorder
767 }
768
769 pub fn get_norms(&self) -> &[f64] {
770 if self.polyvec.is_empty() {
771 panic!("Cannot get norms from empty PiecewiseLegendrePolyVector");
772 }
773 &self.polyvec[0].norms
774 }
775
776 pub fn get_symm(&self) -> Vec<i32> {
777 if self.polyvec.is_empty() {
778 panic!("Cannot get symm from empty PiecewiseLegendrePolyVector");
779 }
780 self.polyvec.iter().map(|poly| poly.symm).collect()
781 }
782
783 pub fn get_data(&self) -> mdarray::DTensor<f64, 3> {
785 if self.polyvec.is_empty() {
786 panic!("Cannot get data from empty PiecewiseLegendrePolyVector");
787 }
788
789 let nsegments = self.polyvec[0].data.shape().1;
790 let polyorder = self.polyvec[0].polyorder;
791 let npolys = self.polyvec.len();
792
793 let mut data = mdarray::DTensor::<f64, 3>::from_elem([nsegments, polyorder, npolys], 0.0);
794
795 for (poly_idx, poly) in self.polyvec.iter().enumerate() {
796 for segment in 0..nsegments {
797 for degree in 0..polyorder {
798 data[[segment, degree, poly_idx]] = poly.data[[degree, segment]];
799 }
800 }
801 }
802
803 data
804 }
805
806 pub fn roots(&self, tolerance: Option<f64>) -> Vec<f64> {
808 if self.polyvec.is_empty() {
809 panic!("Cannot get roots from empty PiecewiseLegendrePolyVector");
810 }
811 const DEFAULT_TOLERANCE: f64 = 1e-10;
812 let tolerance = tolerance.unwrap_or(DEFAULT_TOLERANCE);
813 let mut all_roots = Vec::new();
814
815 for poly in &self.polyvec {
816 let poly_roots = poly.roots();
817 for root in poly_roots {
818 all_roots.push(root);
819 }
820 }
821
822 {
824 all_roots.sort_by(|a, b| b.partial_cmp(a).unwrap());
825 all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
826 }
827 all_roots
828 }
829
830 pub fn last(&self) -> &PiecewiseLegendrePoly {
834 self.polyvec
835 .last()
836 .expect("Cannot get last from empty PiecewiseLegendrePolyVector")
837 }
838
839 pub fn nroots(&self, tolerance: Option<f64>) -> usize {
841 if self.polyvec.is_empty() {
842 panic!("Cannot get nroots from empty PiecewiseLegendrePolyVector");
843 }
844 self.roots(tolerance).len()
845 }
846}
847
848impl std::ops::Index<usize> for PiecewiseLegendrePolyVector {
849 type Output = PiecewiseLegendrePoly;
850
851 fn index(&self, index: usize) -> &Self::Output {
852 &self.polyvec[index]
853 }
854}
855
856pub fn default_sampling_points(u: &PiecewiseLegendrePolyVector, l: usize) -> Vec<f64> {
870 if (u.xmin() - (-1.0)).abs() > 1e-10 || (u.xmax() - 1.0).abs() > 1e-10 {
873 panic!("Expecting unscaled functions here.");
874 }
875
876 let x0 = if l < u.polyvec.len() {
877 u[l].roots()
879 } else {
880 let poly = u.last();
883 let poly_deriv = poly.deriv(1);
884 let maxima = poly_deriv.roots();
885
886 let left = (maxima[0] + poly.xmin) / 2.0;
888
889 let right = (maxima[maxima.len() - 1] + poly.xmax) / 2.0;
891
892 let mut x0_vec = Vec::with_capacity(maxima.len() + 2);
897 x0_vec.push(left);
898 x0_vec.extend_from_slice(&maxima);
899 x0_vec.push(right);
900
901 x0_vec
902 };
903
904 if x0.len() != l {
906 eprintln!(
907 "Warning: Expecting to get {} sampling points for corresponding basis function, \
908 instead got {}. This may happen if not enough precision is left in the polynomial.",
909 l,
910 x0.len()
911 );
912 }
913
914 x0
915}
916
// Unit tests are kept in `poly_tests.rs`, located in the same directory as this source file,
// and are compiled only for test builds.
#[cfg(test)]
#[path = "poly_tests.rs"]
mod poly_tests;