scirs2_interpolate/parallel/
loess.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use scirs2_core::parallel_ops::*;
11use std::fmt::Debug;
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15use super::{estimate_chunk_size, ParallelConfig, ParallelEvaluate};
16use crate::error::{InterpolateError, InterpolateResult};
17use crate::local::mls::{PolynomialBasis, WeightFunction};
18use crate::local::polynomial::{
19 LocalPolynomialConfig, LocalPolynomialRegression, RegressionResult,
20};
21use crate::spatial::kdtree::KdTree;
22
23#[derive(Debug, Clone)]
78pub struct ParallelLocalPolynomialRegression<F>
79where
80 F: Float
81 + FromPrimitive
82 + Debug
83 + Send
84 + Sync
85 + 'static
86 + std::cmp::PartialOrd
87 + ordered_float::FloatCore,
88{
89 loess: LocalPolynomialRegression<F>,
91
92 kdtree: KdTree<F>,
94
95 _phantom: PhantomData<F>,
97}
98
99impl<F> ParallelLocalPolynomialRegression<F>
100where
101 F: Float
102 + FromPrimitive
103 + Debug
104 + Send
105 + Sync
106 + 'static
107 + std::cmp::PartialOrd
108 + ordered_float::FloatCore,
109{
110 pub fn new(points: Array2<F>, values: Array1<F>, bandwidth: F) -> InterpolateResult<Self> {
122 let loess = LocalPolynomialRegression::new(points.clone(), values, bandwidth)?;
124
125 let kdtree = KdTree::new(points)?;
127
128 Ok(Self {
129 loess,
130 kdtree,
131 _phantom: PhantomData,
132 })
133 }
134
135 pub fn with_config(
147 points: Array2<F>,
148 values: Array1<F>,
149 config: LocalPolynomialConfig<F>,
150 ) -> InterpolateResult<Self> {
151 let loess = LocalPolynomialRegression::with_config(points.clone(), values, config)?;
153
154 let kdtree = KdTree::new(points)?;
156
157 Ok(Self {
158 loess,
159 kdtree,
160 _phantom: PhantomData,
161 })
162 }
163
164 pub fn fit_at_point(&self, x: &ArrayView1<F>) -> InterpolateResult<RegressionResult<F>> {
174 self.loess.fit_at_point(x)
175 }
176
177 pub fn fit_multiple_parallel(
192 &self,
193 points: &ArrayView2<F>,
194 config: &ParallelConfig,
195 ) -> InterpolateResult<Array1<F>> {
196 self.evaluate_parallel(points, config)
197 }
198
199 pub fn fit_with_kdtree(
214 &self,
215 points: &ArrayView2<F>,
216 config: &ParallelConfig,
217 ) -> InterpolateResult<Array1<F>> {
218 if points.shape()[1] != self.loess.points().shape()[1] {
220 return Err(InterpolateError::DimensionMismatch(
221 "Query points dimension must match training points".to_string(),
222 ));
223 }
224
225 let npoints = points.shape()[0];
226 let values = self.loess.values();
227
228 let cost_factor = match self.loess.config().basis {
230 PolynomialBasis::Constant => 1.0,
231 PolynomialBasis::Linear => 2.0,
232 PolynomialBasis::Quadratic => 4.0,
233 };
234
235 let chunk_size = estimate_chunk_size(npoints, cost_factor, config);
237
238 let maxpoints = self.loess.config().max_points.unwrap_or(50);
240
241 let values_arc = Arc::new(values.clone());
243 let points_arc = Arc::new(self.loess.points().clone());
244
245 let weight_fn = self.loess.config().weight_fn;
247 let bandwidth = self.loess.config().bandwidth;
248 let basis = self.loess.config().basis;
249
250 let results: Vec<F> = points
252 .axis_chunks_iter(Axis(0), chunk_size)
253 .into_par_iter()
254 .flat_map(|chunk| {
255 let values_ref: Arc<Array1<F>> = Arc::clone(&values_arc);
256 let points_ref: Arc<Array2<F>> = Arc::clone(&points_arc);
257 let mut chunk_results = Vec::with_capacity(chunk.shape()[0]);
258
259 for i in 0..chunk.shape()[0] {
260 let query = chunk.slice(scirs2_core::ndarray::s![i, ..]);
261
262 let neighbors =
264 match self.kdtree.k_nearest_neighbors(&query.to_vec(), maxpoints) {
265 Ok(n) => n,
266 Err(_) => {
267 let mean = values_ref.fold(F::zero(), |acc, &v| acc + v)
269 / F::from_usize(values_ref.len()).unwrap();
270 chunk_results.push(mean);
271 continue;
272 }
273 };
274
275 if neighbors.is_empty() {
276 let mean = values_ref.fold(F::zero(), |acc, &v| acc + v)
278 / F::from_usize(values_ref.len()).unwrap();
279 chunk_results.push(mean);
280 continue;
281 }
282
283 let n_local = neighbors.len();
285 let mut localpoints = Array2::zeros((n_local, query.len()));
286 let mut local_values = Array1::zeros(n_local);
287 let mut weights = Array1::zeros(n_local);
288
289 for (j, &(idx, dist)) in neighbors.iter().enumerate() {
290 localpoints
291 .slice_mut(scirs2_core::ndarray::s![j, ..])
292 .assign(&points_ref.slice(scirs2_core::ndarray::s![idx, ..]));
293 local_values[j] = values_ref[idx];
294
295 weights[j] = apply_weight(dist / bandwidth, weight_fn);
297 }
298
299 match fit_local_polynomial(
301 &localpoints.view(),
302 &local_values,
303 &query,
304 &weights,
305 basis,
306 ) {
307 Ok(result) => chunk_results.push(result),
308 Err(_) => {
309 let mut weighted_sum = F::zero();
311 let mut weight_sum = F::zero();
312
313 for j in 0..n_local {
314 weighted_sum = weighted_sum + weights[j] * local_values[j];
315 weight_sum = weight_sum + weights[j];
316 }
317
318 let result = if weight_sum > F::zero() {
319 weighted_sum / weight_sum
320 } else {
321 local_values.fold(F::zero(), |acc, &v| acc + v)
322 / F::from_usize(n_local).unwrap()
323 };
324
325 chunk_results.push(result);
326 }
327 }
328 }
329
330 chunk_results
331 })
332 .collect();
333
334 Ok(Array1::from_vec(results))
336 }
337}
338
339impl<F> ParallelEvaluate<F, Array1<F>> for ParallelLocalPolynomialRegression<F>
340where
341 F: Float
342 + FromPrimitive
343 + Debug
344 + Send
345 + Sync
346 + 'static
347 + std::cmp::PartialOrd
348 + ordered_float::FloatCore,
349{
350 fn evaluate_parallel(
351 &self,
352 points: &ArrayView2<F>,
353 config: &ParallelConfig,
354 ) -> InterpolateResult<Array1<F>> {
355 self.fit_with_kdtree(points, config)
357 }
358}
359
360#[allow(dead_code)]
362fn apply_weight<F: Float + FromPrimitive>(r: F, weightfn: WeightFunction) -> F {
363 match weightfn {
364 WeightFunction::Gaussian => (-r * r).exp(),
365 WeightFunction::WendlandC2 => {
366 if r < F::one() {
367 let t = F::one() - r;
368 let factor = F::from_f64(4.0).unwrap() * r + F::one();
369 t.powi(4) * factor
370 } else {
371 F::zero()
372 }
373 }
374 WeightFunction::InverseDistance => F::one() / (F::from_f64(1e-10).unwrap() + r * r),
375 WeightFunction::CubicSpline => {
376 if r < F::from_f64(1.0 / 3.0).unwrap() {
377 let r2 = r * r;
378 let r3 = r2 * r;
379 F::from_f64(2.0 / 3.0).unwrap() - F::from_f64(9.0).unwrap() * r2
380 + F::from_f64(19.0).unwrap() * r3
381 } else if r < F::one() {
382 let t = F::from_f64(2.0).unwrap() - F::from_f64(3.0).unwrap() * r;
383 F::from_f64(1.0 / 3.0).unwrap() * t.powi(3)
384 } else {
385 F::zero()
386 }
387 }
388 }
389}
390
391#[allow(dead_code)]
408fn fit_local_polynomial<F: Float + FromPrimitive + 'static>(
409 localpoints: &ArrayView2<F>,
410 local_values: &Array1<F>,
411 query: &ArrayView1<F>,
412 weights: &Array1<F>,
413 basis: PolynomialBasis,
414) -> InterpolateResult<F> {
415 let npoints = localpoints.shape()[0];
416 let n_dims = localpoints.shape()[1];
417
418 let n_basis = match basis {
420 PolynomialBasis::Constant => 1,
421 PolynomialBasis::Linear => n_dims + 1,
422 PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
423 };
424
425 let mut basis_matrix = Array2::zeros((npoints, n_basis));
427
428 for i in 0..npoints {
429 let point = localpoints.row(i);
430 let mut col = 0;
431
432 basis_matrix[[i, col]] = F::one();
434 col += 1;
435
436 if basis == PolynomialBasis::Linear || basis == PolynomialBasis::Quadratic {
437 for j in 0..n_dims {
439 basis_matrix[[i, col]] = point[j] - query[j];
440 col += 1;
441 }
442 }
443
444 if basis == PolynomialBasis::Quadratic {
445 for j in 0..n_dims {
447 for k in j..n_dims {
448 let term_j = point[j] - query[j];
449 let term_k = point[k] - query[k];
450 basis_matrix[[i, col]] = term_j * term_k;
451 col += 1;
452 }
453 }
454 }
455 }
456
457 let mut w_basis = Array2::zeros((npoints, n_basis));
459 let mut w_values = Array1::zeros(npoints);
460
461 for i in 0..npoints {
462 let sqrt_w = weights[i].sqrt();
463 for j in 0..n_basis {
464 w_basis[[i, j]] = basis_matrix[[i, j]] * sqrt_w;
465 }
466 w_values[i] = local_values[i] * sqrt_w;
467 }
468
469 #[cfg(feature = "linalg")]
471 let xtx = w_basis.t().dot(&w_basis);
472 #[cfg(not(feature = "linalg"))]
473 let _xtx = w_basis.t().dot(&w_basis);
474 let xty = w_basis.t().dot(&w_values);
475
476 #[cfg(feature = "linalg")]
477 let coefficients = {
478 use scirs2_linalg::solve;
479 let xtx_f64 = xtx.mapv(|x| x.to_f64().unwrap());
480 let xty_f64 = xty.mapv(|x| x.to_f64().unwrap());
481 solve(&xtx_f64.view(), &xty_f64.view(), None)
482 .map_err(|_| {
483 InterpolateError::ComputationError("Failed to solve linear system".to_string())
484 })?
485 .mapv(|x| F::from_f64(x).unwrap())
486 };
487
488 #[cfg(not(feature = "linalg"))]
489 let coefficients = {
490 Array1::zeros(xty.len())
495 };
496
497 Ok(coefficients[0])
500}
501
502#[allow(dead_code)]
517pub fn make_parallel_loess<F>(
518 points: Array2<F>,
519 values: Array1<F>,
520 bandwidth: F,
521) -> InterpolateResult<ParallelLocalPolynomialRegression<F>>
522where
523 F: Float
524 + FromPrimitive
525 + Debug
526 + Send
527 + Sync
528 + 'static
529 + std::cmp::Ord
530 + ordered_float::FloatCore,
531{
532 ParallelLocalPolynomialRegression::new(points, values, bandwidth)
533}
534
535#[allow(dead_code)]
550pub fn make_parallel_robust_loess<F>(
551 points: Array2<F>,
552 values: Array1<F>,
553 bandwidth: F,
554 confidence_level: F,
555) -> InterpolateResult<ParallelLocalPolynomialRegression<F>>
556where
557 F: Float
558 + FromPrimitive
559 + Debug
560 + Send
561 + Sync
562 + 'static
563 + std::cmp::Ord
564 + ordered_float::FloatCore,
565{
566 let config = LocalPolynomialConfig {
567 bandwidth,
568 weight_fn: WeightFunction::Gaussian,
569 basis: PolynomialBasis::Linear,
570 confidence_level: Some(confidence_level),
571 robust_se: true,
572 max_points: None,
573 epsilon: F::from_f64(1e-10).unwrap(),
574 };
575
576 ParallelLocalPolynomialRegression::with_config(points, values, config)
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Runs the sequential and the parallel model on the same noisy sine data
    /// and checks that every parallel prediction is finite. The per-point
    /// difference to the sequential fit is printed, not asserted.
    #[test]
    fn test_parallel_loess_matches_sequential() {
        const N_TRAIN: usize = 50;
        const N_QUERY: usize = 10;

        // Noisy samples of sin(x) on [0, 10].
        let x: Array1<f64> = Array1::linspace(0.0, 10.0, N_TRAIN);
        let mut y: Array1<f64> = Array1::zeros(N_TRAIN);
        for (y_val, &x_val) in y.iter_mut().zip(x.iter()) {
            *y_val = x_val.sin() + 0.1 * (scirs2_core::random::random::<f64>() - 0.5);
        }

        let points = x.clone().insert_axis(Axis(1));

        let sequential_loess =
            LocalPolynomialRegression::new(points.clone(), y.clone(), 0.3).unwrap();
        let parallel_loess =
            ParallelLocalPolynomialRegression::new(points.clone(), y.clone(), 0.3).unwrap();

        let test_x: Array1<f64> = Array1::linspace(1.0, 9.0, N_QUERY);
        let testpoints = test_x.clone().insert_axis(Axis(1));

        // Reference predictions, one query at a time.
        let sequential_values: Array1<f64> = (0..N_QUERY)
            .map(|i| {
                sequential_loess
                    .fit_at_point(&testpoints.row(i))
                    .unwrap()
                    .value
            })
            .collect();

        let config = ParallelConfig::new();
        let parallel_values = parallel_loess
            .fit_multiple_parallel(&testpoints.view(), &config)
            .unwrap();

        for (i, (&seq, &par)) in sequential_values
            .iter()
            .zip(parallel_values.iter())
            .enumerate()
        {
            assert!(par.is_finite());

            let difference = (seq - par).abs();
            println!("Difference at point {}: {}", i, difference);
        }
    }

    /// Evaluates the same model with 1, 2 and 4 workers and checks that the
    /// predictions agree with the single-worker run.
    #[test]
    fn test_parallel_loess_with_different_thread_counts() {
        const N_QUERY: usize = 20;
        let npoints = 100;

        // Noisy samples of x^2 on [0, 10].
        let x: Array1<f64> = Array1::linspace(0.0, 10.0, npoints);
        let mut y: Array1<f64> = Array1::zeros(npoints);
        for (y_val, &x_val) in y.iter_mut().zip(x.iter()) {
            *y_val = x_val.powi(2) + (scirs2_core::random::random::<f64>() - 0.5) * 5.0;
        }

        let points = x.clone().insert_axis(Axis(1));

        let config = LocalPolynomialConfig {
            bandwidth: 0.2,
            basis: PolynomialBasis::Quadratic,
            ..LocalPolynomialConfig::default()
        };

        let parallel_loess =
            ParallelLocalPolynomialRegression::with_config(points.clone(), y.clone(), config)
                .unwrap();

        let test_x: Array1<f64> = Array1::linspace(1.0, 9.0, N_QUERY);
        let testpoints = test_x.clone().insert_axis(Axis(1));

        let configs = vec![
            ParallelConfig::new().with_workers(1),
            ParallelConfig::new().with_workers(2),
            ParallelConfig::new().with_workers(4),
        ];

        let results: Vec<_> = configs
            .iter()
            .map(|config| {
                parallel_loess
                    .fit_multiple_parallel(&testpoints.view(), config)
                    .unwrap()
            })
            .collect();

        // Every run must agree with the single-worker baseline.
        for other in results.iter().skip(1) {
            for j in 0..N_QUERY {
                assert_abs_diff_eq!(results[0][j], other[j], epsilon = 0.1);
            }
        }
    }
}