1use crate::advanced::rbf::RBFKernel;
40use crate::cache::CacheConfig;
41use crate::error::{InterpolateError, InterpolateResult};
42use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
43use scirs2_core::numeric::{Float, FromPrimitive};
44use std::fmt::{Debug, Display};
45
46#[derive(Debug, Clone)]
48pub struct CacheOptimizedConfig {
49 pub block_size: usize,
51 pub enable_prefetching: bool,
53 pub numa_aware: bool,
55 pub num_threads: Option<usize>,
57 pub cache_config: CacheConfig,
59 pub cache_sizes: CacheSizes,
61}
62
/// Cache capacities, in kilobytes as the field names indicate.
#[derive(Clone, Debug)]
pub struct CacheSizes {
    /// L1 cache size in KB.
    pub l1_cache_kb: usize,
    /// L2 cache size in KB.
    pub l2_cache_kb: usize,
    /// L3 cache size in KB.
    pub l3_cache_kb: usize,
}
73
74impl Default for CacheSizes {
75 fn default() -> Self {
76 Self {
77 l1_cache_kb: 32, l2_cache_kb: 256, l3_cache_kb: 8192, }
81 }
82}
83
84impl Default for CacheOptimizedConfig {
85 fn default() -> Self {
86 Self {
87 block_size: 64, enable_prefetching: true,
89 numa_aware: false,
90 num_threads: None, cache_config: CacheConfig::default(),
92 cache_sizes: CacheSizes::default(),
93 }
94 }
95}
96
97#[derive(Debug)]
99pub struct CacheAwareRBF<F>
100where
101 F: Float + FromPrimitive + Debug + Display + Send + Sync + 'static,
102{
103 points: Array2<F>,
105 values: Array1<F>,
107 kernel: RBFKernel,
109 epsilon: F,
111 coefficients: Option<Array1<F>>,
113 _config: CacheOptimizedConfig,
115 stats: CacheOptimizedStats,
117}
118
/// Counters collected by the cache-aware RBF interpolator.
#[derive(Default, Debug)]
pub struct CacheOptimizedStats {
    /// Number of query points evaluated so far.
    pub evaluations: usize,
    /// Accumulated wall-clock time of coefficient precomputation, in nanoseconds.
    pub blocked_ops_time_ns: u64,
    /// Accumulated number of blocks counted during precomputation.
    pub blocks_processed: usize,
    /// Cache-miss estimate (never written by the code in this file).
    pub estimated_cache_miss_rate: f64,
    /// Bandwidth-utilisation estimate (never written by the code in this file).
    pub memory_bandwidth_utilization: f64,
}
133
134impl<F> CacheAwareRBF<F>
135where
136 F: Float + FromPrimitive + Debug + Display + Send + Sync + 'static,
137{
138 pub fn new(
150 points: Array2<F>,
151 values: Array1<F>,
152 config: CacheOptimizedConfig,
153 ) -> InterpolateResult<Self> {
154 if points.nrows() != values.len() {
155 return Err(InterpolateError::invalid_input(
156 "Number of points must match number of values".to_string(),
157 ));
158 }
159
160 let kernel = RBFKernel::Gaussian; let epsilon = F::one(); Ok(Self {
164 points,
165 values,
166 kernel,
167 epsilon,
168 coefficients: None,
169 _config: config,
170 stats: CacheOptimizedStats::default(),
171 })
172 }
173
174 pub fn with_kernel(mut self, kernel: RBFKernel) -> Self {
176 self.kernel = kernel;
177 self
178 }
179
180 pub fn with_epsilon(mut self, epsilon: F) -> Self {
182 self.epsilon = epsilon;
183 self
184 }
185
186 pub fn precompute_coefficients(&mut self) -> InterpolateResult<()> {
191 let start_time = std::time::Instant::now();
192
193 let n_points = self.points.nrows();
194 let block_size = self._config.block_size.min(n_points);
195
196 let distance_matrix = self.compute_blocked_distance_matrix()?;
198
199 let kernel_matrix = self.apply_kernel_blocked(&distance_matrix)?;
201
202 let coefficients = self.solve_rbf_system_optimized(&kernel_matrix)?;
204
205 self.coefficients = Some(coefficients);
206 self.stats.blocks_processed += n_points.div_ceil(block_size);
207 self.stats.blocked_ops_time_ns += start_time.elapsed().as_nanos() as u64;
208
209 Ok(())
210 }
211
212 fn compute_blocked_distance_matrix(&self) -> InterpolateResult<Array2<F>> {
214 let n_points = self.points.nrows();
215 let n_dims = self.points.ncols();
216 let block_size = self._config.block_size;
217
218 let mut distance_matrix = Array2::zeros((n_points, n_points));
219
220 for i_block in (0..n_points).step_by(block_size) {
222 let i_end = (i_block + block_size).min(n_points);
223
224 for j_block in (0..n_points).step_by(block_size) {
225 let j_end = (j_block + block_size).min(n_points);
226
227 for i in i_block..i_end {
229 for j in j_block..j_end {
230 if i <= j {
231 let dist = self.compute_distance_optimized(i, j, n_dims);
232 distance_matrix[[i, j]] = dist;
233 distance_matrix[[j, i]] = dist; }
235 }
236 }
237 }
238 }
239
240 Ok(distance_matrix)
241 }
242
243 fn compute_distance_optimized(&self, i: usize, j: usize, ndims: usize) -> F {
245 let mut sum_sq = F::zero();
246
247 let chunk_size = 4; for dim_chunk in (0..ndims).step_by(chunk_size) {
251 let end_dim = (dim_chunk + chunk_size).min(ndims);
252
253 for dim in dim_chunk..end_dim {
254 let diff = self.points[[i, dim]] - self.points[[j, dim]];
255 sum_sq = sum_sq + diff * diff;
256 }
257 }
258
259 sum_sq.sqrt()
260 }
261
262 fn apply_kernel_blocked(&self, distance_matrix: &Array2<F>) -> InterpolateResult<Array2<F>> {
264 let n_points = distance_matrix.nrows();
265 let block_size = self._config.block_size;
266 let mut kernel_matrix = Array2::zeros((n_points, n_points));
267
268 for i_block in (0..n_points).step_by(block_size) {
269 let i_end = (i_block + block_size).min(n_points);
270
271 for j_block in (0..n_points).step_by(block_size) {
272 let j_end = (j_block + block_size).min(n_points);
273
274 for i in i_block..i_end {
276 for j in j_block..j_end {
277 let dist = distance_matrix[[i, j]];
278 kernel_matrix[[i, j]] = self.apply_kernel_function(dist);
279 }
280 }
281 }
282 }
283
284 Ok(kernel_matrix)
285 }
286
287 fn apply_kernel_function(&self, distance: F) -> F {
289 match self.kernel {
290 RBFKernel::Gaussian => {
291 let arg = -(distance * distance) / (self.epsilon * self.epsilon);
292 arg.exp()
293 }
294 RBFKernel::Multiquadric => (distance * distance + self.epsilon * self.epsilon).sqrt(),
295 RBFKernel::InverseMultiquadric => {
296 F::one() / (distance * distance + self.epsilon * self.epsilon).sqrt()
297 }
298 RBFKernel::Linear => distance,
299 RBFKernel::Cubic => distance * distance * distance,
300 RBFKernel::ThinPlateSpline => {
301 if distance == F::zero() {
302 F::zero()
303 } else {
304 distance * distance * distance.ln()
305 }
306 }
307 RBFKernel::Quintic => {
308 let r2 = distance * distance;
309 distance * r2 * r2 }
311 }
312 }
313
314 fn solve_rbf_system_optimized(
316 &self,
317 kernel_matrix: &Array2<F>,
318 ) -> InterpolateResult<Array1<F>> {
319 let n = kernel_matrix.nrows();
324 let mut augmented = Array2::zeros((n, n + 1));
325
326 for i in 0..n {
328 for j in 0..n {
329 augmented[[i, j]] = kernel_matrix[[i, j]];
330 }
331 augmented[[i, n]] = self.values[i];
332 }
333
334 for k in 0..n {
336 let mut max_row = k;
338 for i in (k + 1)..n {
339 if augmented[[i, k]].abs() > augmented[[max_row, k]].abs() {
340 max_row = i;
341 }
342 }
343
344 if max_row != k {
346 for j in 0..=n {
347 let temp = augmented[[k, j]];
348 augmented[[k, j]] = augmented[[max_row, j]];
349 augmented[[max_row, j]] = temp;
350 }
351 }
352
353 for i in (k + 1)..n {
355 if augmented[[k, k]] != F::zero() {
356 let factor = augmented[[i, k]] / augmented[[k, k]];
357 for j in k..=n {
358 augmented[[i, j]] = augmented[[i, j]] - factor * augmented[[k, j]];
359 }
360 }
361 }
362 }
363
364 let mut solution = Array1::zeros(n);
366 for i in (0..n).rev() {
367 let mut sum = F::zero();
368 for j in (i + 1)..n {
369 sum = sum + augmented[[i, j]] * solution[j];
370 }
371 if augmented[[i, i]] != F::zero() {
372 solution[i] = (augmented[[i, n]] - sum) / augmented[[i, i]];
373 }
374 }
375
376 Ok(solution)
377 }
378
379 pub fn evaluate_cache_optimized(
389 &mut self,
390 query_points: &ArrayView2<F>,
391 ) -> InterpolateResult<Array1<F>> {
392 if self.coefficients.is_none() {
393 self.precompute_coefficients()?;
394 }
395
396 let coefficients = self.coefficients.as_ref().unwrap();
397 let n_queries = query_points.nrows();
398 let n_points = self.points.nrows();
399 let block_size = self._config.block_size;
400
401 let mut results = Array1::zeros(n_queries);
402
403 for query_block in (0..n_queries).step_by(block_size) {
405 let query_end = (query_block + block_size).min(n_queries);
406
407 for query_idx in query_block..query_end {
408 let query = query_points.row(query_idx);
409 let mut value = F::zero();
410
411 for point_block in (0..n_points).step_by(block_size) {
413 let point_end = (point_block + block_size).min(n_points);
414
415 for point_idx in point_block..point_end {
416 let point = self.points.row(point_idx);
417 let distance = self.compute_query_distance(&query, &point);
418 let kernel_value = self.apply_kernel_function(distance);
419 value = value + coefficients[point_idx] * kernel_value;
420 }
421 }
422
423 results[query_idx] = value;
424 }
425 }
426
427 self.stats.evaluations += n_queries;
428 Ok(results)
429 }
430
431 fn compute_query_distance(&self, query: &ArrayView1<F>, point: &ArrayView1<F>) -> F {
433 let mut sum_sq = F::zero();
434 for i in 0..query.len() {
435 let diff = query[i] - point[i];
436 sum_sq = sum_sq + diff * diff;
437 }
438 sum_sq.sqrt()
439 }
440
441 pub fn stats(&self) -> &CacheOptimizedStats {
443 &self.stats
444 }
445
446 pub fn reset_stats(&mut self) {
448 self.stats = CacheOptimizedStats::default();
449 }
450}
451
452#[derive(Debug)]
454pub struct CacheAwareBSpline<F>
455where
456 F: Float + FromPrimitive + Debug + Display + Copy + 'static,
457{
458 knots: Array1<F>,
460 coefficients: Array1<F>,
462 degree: usize,
464 config: CacheOptimizedConfig,
466 #[allow(dead_code)]
468 basis_cache: std::collections::HashMap<u64, Vec<F>>,
469}
470
471impl<F> CacheAwareBSpline<F>
472where
473 F: Float + FromPrimitive + Debug + Display + Copy + 'static,
474{
475 pub fn new(
477 knots: Array1<F>,
478 coefficients: Array1<F>,
479 degree: usize,
480 config: CacheOptimizedConfig,
481 ) -> InterpolateResult<Self> {
482 if knots.len() < coefficients.len() + degree + 1 {
483 return Err(InterpolateError::invalid_input(
484 "Invalid knot vector length for given coefficients and degree".to_string(),
485 ));
486 }
487
488 Ok(Self {
489 knots,
490 coefficients,
491 degree,
492 config,
493 basis_cache: std::collections::HashMap::new(),
494 })
495 }
496
497 pub fn evaluate_batch_cache_optimized(
499 &mut self,
500 x_values: &ArrayView1<F>,
501 ) -> InterpolateResult<Array1<F>> {
502 let n_points = x_values.len();
503 let mut results = Array1::zeros(n_points);
504 let block_size = self.config.block_size;
505
506 for block_start in (0..n_points).step_by(block_size) {
508 let block_end = (block_start + block_size).min(n_points);
509
510 for i in block_start..block_end {
511 let x = x_values[i];
512 results[i] = self.evaluate_single_optimized(x)?;
513 }
514 }
515
516 Ok(results)
517 }
518
519 fn evaluate_single_optimized(&mut self, x: F) -> InterpolateResult<F> {
521 let span = self.find_knot_span(x);
523
524 let basis = self.compute_basis_functions(x, span);
526
527 let mut result = F::zero();
529 for (i, &basis_val) in basis.iter().enumerate().take(self.degree + 1) {
530 let coeff_idx = span - self.degree + i;
531 if coeff_idx < self.coefficients.len() {
532 result = result + self.coefficients[coeff_idx] * basis_val;
533 }
534 }
535
536 Ok(result)
537 }
538
539 fn find_knot_span(&self, x: F) -> usize {
541 let n = self.coefficients.len();
542
543 if x >= self.knots[n] {
544 return n - 1;
545 }
546 if x <= self.knots[self.degree] {
547 return self.degree;
548 }
549
550 let mut low = self.degree;
552 let mut high = n;
553 let mut mid = (low + high) / 2;
554
555 while x < self.knots[mid] || x >= self.knots[mid + 1] {
556 if x < self.knots[mid] {
557 high = mid;
558 } else {
559 low = mid;
560 }
561 mid = (low + high) / 2;
562 }
563
564 mid
565 }
566
567 fn compute_basis_functions(&self, x: F, span: usize) -> Vec<F> {
569 let mut basis = vec![F::zero(); self.degree + 1];
570 basis[0] = F::one();
571
572 for j in 1..=self.degree {
573 let mut saved = F::zero();
574 #[allow(clippy::needless_range_loop)]
575 for r in 0..j {
576 let temp = basis[r];
577 let alpha_1 = if span + 1 + r >= j && span + 1 + r < self.knots.len() {
578 let denom = self.knots[span + 1 + r] - self.knots[span + 1 + r - j];
579 if denom != F::zero() {
580 (x - self.knots[span + 1 + r - j]) / denom
581 } else {
582 F::zero()
583 }
584 } else {
585 F::zero()
586 };
587
588 basis[r] = saved + (F::one() - alpha_1) * temp;
589 saved = alpha_1 * temp;
590 }
591 basis[j] = saved;
592 }
593
594 basis
595 }
596}
597
598#[allow(dead_code)]
611pub fn make_cache_aware_rbf<F>(
612 points: Array2<F>,
613 values: Array1<F>,
614 kernel: RBFKernel,
615 epsilon: F,
616) -> InterpolateResult<CacheAwareRBF<F>>
617where
618 F: Float + FromPrimitive + Debug + Display + Send + Sync + 'static,
619{
620 let config = CacheOptimizedConfig::default();
621
622 let mut rbf = CacheAwareRBF::new(points, values, config)?
623 .with_kernel(kernel)
624 .with_epsilon(epsilon);
625
626 rbf.precompute_coefficients()?;
627 Ok(rbf)
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Four 2-D points with four values must be accepted.
    #[test]
    fn test_cache_aware_rbf_creation() {
        let rbf = CacheAwareRBF::new(
            array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
            array![0.0, 1.0, 1.0, 2.0],
            CacheOptimizedConfig::default(),
        );

        assert!(rbf.is_ok());
    }

    /// Eight knots are exactly enough for five coefficients of degree two.
    #[test]
    fn test_cache_aware_bspline_creation() {
        let spline = CacheAwareBSpline::new(
            array![0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0],
            array![1.0, 2.0, 3.0, 4.0, 5.0],
            2,
            CacheOptimizedConfig::default(),
        );

        assert!(spline.is_ok());
    }

    /// The default configuration exposes the documented values.
    #[test]
    fn test_cache_optimized_config_defaults() {
        let CacheOptimizedConfig {
            block_size,
            enable_prefetching,
            numa_aware,
            num_threads,
            ..
        } = CacheOptimizedConfig::default();

        assert_eq!(block_size, 64);
        assert!(enable_prefetching);
        assert!(!numa_aware);
        assert!(num_threads.is_none());
    }

    /// The default cache sizes are 32 / 256 / 8192 KB.
    #[test]
    fn test_cache_sizes_defaults() {
        let CacheSizes {
            l1_cache_kb,
            l2_cache_kb,
            l3_cache_kb,
        } = CacheSizes::default();

        assert_eq!(l1_cache_kb, 32);
        assert_eq!(l2_cache_kb, 256);
        assert_eq!(l3_cache_kb, 8192);
    }
}