1use super::types::*;
8use crate::error::InterpolateResult;
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11use std::fmt::Debug;
12use std::time::Instant;
13
14#[derive(Debug)]
16pub struct AccuracyOptimizationEngine<F: Float + Debug> {
17 strategy: AccuracyOptimizationStrategy,
19 targets: AccuracyTargets<F>,
21 error_predictor: ErrorPredictionModel<F>,
23 optimization_history: VecDeque<AccuracyOptimizationResult>,
25}
26
/// How an [`AccuracyOptimizationEngine`] trades accuracy against computational cost.
#[derive(Debug, Clone)]
pub enum AccuracyOptimizationStrategy {
    /// Push parameters toward the most accurate setting (tight tolerances, more iterations,
    /// higher spline degree, less smoothing).
    MaximizeAccuracy,

    /// Choose parameters from the data's noise level and smoothness.
    BalancedAccuracy,

    /// Loosen tolerances and cap iterations to favor speed.
    MinimumAccuracy,

    /// Pick one of the strategies above from the data's noise level and size.
    Adaptive,

    /// Blend the accuracy-maximizing and performance-oriented parameter sets
    /// using the supplied weights.
    Custom {
        /// Weight given to the accuracy-maximizing parameters.
        accuracy_weight: f64,
        /// Weight given to the performance-oriented parameters.
        performance_weight: f64,
    },
}
44
45#[derive(Debug, Clone)]
47pub struct AccuracyTargets<F: Float> {
48 pub target_absolute_error: Option<F>,
50 pub target_relative_error: Option<F>,
52 pub max_acceptable_error: F,
54 pub confidence_level: F,
56}
57
58#[derive(Debug)]
60pub struct ErrorPredictionModel<F: Float> {
61 prediction_params: HashMap<String, F>,
63 error_history: VecDeque<ErrorRecord<F>>,
65 model_accuracy: F,
67}
68
69#[derive(Debug, Clone)]
71pub struct ErrorRecord<F: Float> {
72 pub predicted_error: F,
74 pub actual_error: F,
76 pub data_characteristics: String,
78 pub method: InterpolationMethodType,
80 pub timestamp: Instant,
82}
83
84#[derive(Debug, Clone)]
86pub struct AccuracyOptimizationResult {
87 pub method: InterpolationMethodType,
89 pub adjusted_parameters: HashMap<String, f64>,
91 pub accuracy_improvement: f64,
93 pub performance_impact: f64,
95 pub success: bool,
97 pub timestamp: Instant,
99}
100
101impl<F: Float + Debug + std::ops::AddAssign> AccuracyOptimizationEngine<F> {
102 pub fn new() -> InterpolateResult<Self> {
104 Ok(Self {
105 strategy: AccuracyOptimizationStrategy::BalancedAccuracy,
106 targets: AccuracyTargets::default(),
107 error_predictor: ErrorPredictionModel::new()?,
108 optimization_history: VecDeque::new(),
109 })
110 }
111
112 pub fn set_strategy(&mut self, strategy: AccuracyOptimizationStrategy) {
114 self.strategy = strategy;
115 }
116
117 pub fn set_targets(&mut self, targets: AccuracyTargets<F>) {
119 self.targets = targets;
120 }
121
122 pub fn optimize_accuracy(
124 &mut self,
125 method: InterpolationMethodType,
126 data_profile: &DataProfile<F>,
127 current_parameters: &HashMap<String, f64>,
128 ) -> InterpolateResult<AccuracyOptimizationResult> {
129 let start_time = Instant::now();
130
131 let predicted_accuracy = self.predict_accuracy(method, data_profile, current_parameters)?;
133
134 if self.meets_accuracy_targets(&predicted_accuracy)? {
136 return Ok(AccuracyOptimizationResult {
137 method,
138 adjusted_parameters: current_parameters.clone(),
139 accuracy_improvement: 0.0,
140 performance_impact: 0.0,
141 success: true,
142 timestamp: start_time,
143 });
144 }
145
146 let optimized_params = match &self.strategy {
148 AccuracyOptimizationStrategy::MaximizeAccuracy => {
149 self.maximize_accuracy_optimization(method, data_profile, current_parameters)?
150 }
151 AccuracyOptimizationStrategy::BalancedAccuracy => {
152 self.balanced_optimization(method, data_profile, current_parameters)?
153 }
154 AccuracyOptimizationStrategy::MinimumAccuracy => {
155 self.minimum_accuracy_optimization(method, data_profile, current_parameters)?
156 }
157 AccuracyOptimizationStrategy::Adaptive => {
158 self.adaptive_optimization(method, data_profile, current_parameters)?
159 }
160 AccuracyOptimizationStrategy::Custom {
161 accuracy_weight,
162 performance_weight,
163 } => self.custom_weighted_optimization(
164 method,
165 data_profile,
166 current_parameters,
167 *accuracy_weight,
168 *performance_weight,
169 )?,
170 };
171
172 let optimized_accuracy = self.predict_accuracy(method, data_profile, &optimized_params)?;
174 let accuracy_improvement = optimized_accuracy
175 .predicted_accuracy
176 .to_f64()
177 .unwrap_or(0.0)
178 - predicted_accuracy
179 .predicted_accuracy
180 .to_f64()
181 .unwrap_or(0.0);
182
183 let performance_impact =
184 self.estimate_performance_impact(&optimized_params, current_parameters);
185
186 let result = AccuracyOptimizationResult {
187 method,
188 adjusted_parameters: optimized_params,
189 accuracy_improvement,
190 performance_impact,
191 success: accuracy_improvement > 0.0,
192 timestamp: start_time,
193 };
194
195 self.optimization_history.push_back(result.clone());
197 if self.optimization_history.len() > 100 {
198 self.optimization_history.pop_front();
199 }
200
201 Ok(result)
202 }
203
204 pub fn predict_accuracy(
206 &self,
207 method: InterpolationMethodType,
208 data_profile: &DataProfile<F>,
209 parameters: &HashMap<String, f64>,
210 ) -> InterpolateResult<AccuracyPrediction<F>> {
211 self.error_predictor
212 .predict_accuracy(method, data_profile, parameters)
213 }
214
215 pub fn update_error_model(
217 &mut self,
218 method: InterpolationMethodType,
219 data_profile: &DataProfile<F>,
220 predicted_error: F,
221 actual_error: F,
222 ) -> InterpolateResult<()> {
223 let error_record = ErrorRecord {
224 predicted_error,
225 actual_error,
226 data_characteristics: format!(
227 "size:{},dim:{}",
228 data_profile.size, data_profile.dimensionality
229 ),
230 method,
231 timestamp: Instant::now(),
232 };
233
234 self.error_predictor.add_error_record(error_record)?;
235 self.error_predictor.update_model()?;
236
237 Ok(())
238 }
239
240 pub fn get_optimization_history(&self) -> &VecDeque<AccuracyOptimizationResult> {
242 &self.optimization_history
243 }
244
245 pub fn get_targets(&self) -> &AccuracyTargets<F> {
247 &self.targets
248 }
249
250 fn meets_accuracy_targets(
252 &self,
253 prediction: &AccuracyPrediction<F>,
254 ) -> InterpolateResult<bool> {
255 let predicted_error = prediction.predicted_accuracy;
256
257 if predicted_error > self.targets.max_acceptable_error {
259 return Ok(false);
260 }
261
262 if let Some(target_abs) = self.targets.target_absolute_error {
264 if predicted_error > target_abs {
265 return Ok(false);
266 }
267 }
268
269 if let Some(target_rel) = self.targets.target_relative_error {
271 if predicted_error > target_rel {
274 return Ok(false);
275 }
276 }
277
278 Ok(true)
279 }
280
281 fn maximize_accuracy_optimization(
283 &self,
284 method: InterpolationMethodType,
285 data_profile: &DataProfile<F>,
286 current_parameters: &HashMap<String, f64>,
287 ) -> InterpolateResult<HashMap<String, f64>> {
288 let mut optimized = current_parameters.clone();
289
290 match method {
291 InterpolationMethodType::CubicSpline => {
292 if let Some(smoothing) = optimized.get_mut("smoothing") {
294 *smoothing *= 0.1;
295 }
296 }
297 InterpolationMethodType::BSpline => {
298 if let Some(degree) = optimized.get_mut("degree") {
300 *degree = (*degree + 1.0).min(5.0);
301 }
302 }
303 InterpolationMethodType::RadialBasisFunction => {
304 if let Some(shape) = optimized.get_mut("shape_parameter") {
306 *shape = self.optimize_rbf_shape_parameter(data_profile);
307 }
308 }
309 _ => {
310 optimized.insert("tolerance".to_string(), 1e-12);
312 optimized.insert("max_iterations".to_string(), 1000.0);
313 }
314 }
315
316 Ok(optimized)
317 }
318
319 fn balanced_optimization(
321 &self,
322 method: InterpolationMethodType,
323 data_profile: &DataProfile<F>,
324 current_parameters: &HashMap<String, f64>,
325 ) -> InterpolateResult<HashMap<String, f64>> {
326 let mut optimized = current_parameters.clone();
327
328 let noise_level = data_profile.noise_level.to_f64().unwrap_or(0.1);
330 let smoothness = data_profile.smoothness.to_f64().unwrap_or(0.5);
331
332 match method {
333 InterpolationMethodType::CubicSpline => {
334 let smoothing_factor = if noise_level > 0.1 {
336 noise_level * 0.5
337 } else {
338 0.01
339 };
340 optimized.insert("smoothing".to_string(), smoothing_factor);
341 }
342 InterpolationMethodType::BSpline => {
343 let degree = if smoothness > 0.8 { 3.0 } else { 2.0 };
345 optimized.insert("degree".to_string(), degree);
346 }
347 _ => {
348 optimized.insert("tolerance".to_string(), 1e-8);
350 optimized.insert("max_iterations".to_string(), 100.0);
351 }
352 }
353
354 Ok(optimized)
355 }
356
357 fn minimum_accuracy_optimization(
359 &self,
360 _method: InterpolationMethodType,
361 _data_profile: &DataProfile<F>,
362 current_parameters: &HashMap<String, f64>,
363 ) -> InterpolateResult<HashMap<String, f64>> {
364 let mut optimized = current_parameters.clone();
365
366 optimized.insert("tolerance".to_string(), 1e-4);
368 optimized.insert("max_iterations".to_string(), 50.0);
369
370 Ok(optimized)
371 }
372
373 fn adaptive_optimization(
375 &self,
376 method: InterpolationMethodType,
377 data_profile: &DataProfile<F>,
378 current_parameters: &HashMap<String, f64>,
379 ) -> InterpolateResult<HashMap<String, f64>> {
380 let noise_level = data_profile.noise_level.to_f64().unwrap_or(0.1);
381 let data_size = data_profile.size;
382
383 if noise_level > 0.2 {
385 self.balanced_optimization(method, data_profile, current_parameters)
387 } else if data_size > 10000 {
388 self.minimum_accuracy_optimization(method, data_profile, current_parameters)
390 } else {
391 self.maximize_accuracy_optimization(method, data_profile, current_parameters)
393 }
394 }
395
396 fn custom_weighted_optimization(
398 &self,
399 method: InterpolationMethodType,
400 data_profile: &DataProfile<F>,
401 current_parameters: &HashMap<String, f64>,
402 accuracy_weight: f64,
403 performance_weight: f64,
404 ) -> InterpolateResult<HashMap<String, f64>> {
405 let accuracy_params =
407 self.maximize_accuracy_optimization(method, data_profile, current_parameters)?;
408 let performance_params =
409 self.minimum_accuracy_optimization(method, data_profile, current_parameters)?;
410
411 let mut optimized = HashMap::new();
412
413 for (key, &acc_val) in &accuracy_params {
415 let perf_val = performance_params.get(key).copied().unwrap_or(acc_val);
416 let weighted_val = accuracy_weight * acc_val + performance_weight * perf_val;
417 optimized.insert(key.clone(), weighted_val);
418 }
419
420 Ok(optimized)
421 }
422
423 fn optimize_rbf_shape_parameter(&self, data_profile: &DataProfile<F>) -> f64 {
425 let typical_distance = (data_profile.value_range.1 - data_profile.value_range.0)
426 .to_f64()
427 .unwrap_or(1.0)
428 / (data_profile.size as f64).sqrt();
429
430 1.0 / typical_distance
432 }
433
434 fn estimate_performance_impact(
436 &self,
437 optimized_params: &HashMap<String, f64>,
438 current_params: &HashMap<String, f64>,
439 ) -> f64 {
440 let mut impact = 0.0;
441
442 if let (Some(&opt_tol), Some(&cur_tol)) = (
444 optimized_params.get("tolerance"),
445 current_params.get("tolerance"),
446 ) {
447 if opt_tol < cur_tol {
448 impact += (cur_tol / opt_tol).log10() * 0.1; }
450 }
451
452 if let (Some(&opt_iter), Some(&cur_iter)) = (
454 optimized_params.get("max_iterations"),
455 current_params.get("max_iterations"),
456 ) {
457 impact += (opt_iter / cur_iter - 1.0) * 0.05; }
459
460 if let (Some(&opt_deg), Some(&cur_deg)) =
462 (optimized_params.get("degree"), current_params.get("degree"))
463 {
464 impact += (opt_deg - cur_deg) * 0.15; }
466
467 impact.max(-0.5).min(2.0) }
469}
470
471impl<F: Float> Default for AccuracyTargets<F> {
472 fn default() -> Self {
473 Self {
474 target_absolute_error: None,
475 target_relative_error: None,
476 max_acceptable_error: F::from(1e-6).unwrap(),
477 confidence_level: F::from(0.95).unwrap(),
478 }
479 }
480}
481
482impl<F: Float + std::ops::AddAssign> ErrorPredictionModel<F> {
483 pub fn new() -> InterpolateResult<Self> {
485 Ok(Self {
486 prediction_params: HashMap::new(),
487 error_history: VecDeque::new(),
488 model_accuracy: F::from(0.8).unwrap(),
489 })
490 }
491
492 pub fn predict_accuracy(
494 &self,
495 method: InterpolationMethodType,
496 data_profile: &DataProfile<F>,
497 _parameters: &HashMap<String, f64>,
498 ) -> InterpolateResult<AccuracyPrediction<F>> {
499 let base_accuracy = self.get_base_accuracy(method);
501 let noise_penalty = data_profile.noise_level.to_f64().unwrap_or(0.1) * 0.5;
502 let size_bonus = if data_profile.size > 1000 { 0.05 } else { 0.0 };
503
504 let predicted_error = F::from(1.0 - base_accuracy + noise_penalty - size_bonus).unwrap();
505 let confidence = self.model_accuracy;
506
507 Ok(AccuracyPrediction {
508 predicted_accuracy: predicted_error.max(F::from(1e-12).unwrap()),
509 confidence_interval: (
510 predicted_error * F::from(0.8).unwrap(),
511 predicted_error * F::from(1.2).unwrap(),
512 ),
513 prediction_confidence: confidence,
514 accuracy_factors: vec![
515 AccuracyFactor {
516 name: "Method capability".to_string(),
517 impact: F::from(base_accuracy - 0.5).unwrap(),
518 confidence: F::from(0.9).unwrap(),
519 mitigations: vec!["Consider higher-order methods".to_string()],
520 },
521 AccuracyFactor {
522 name: "Data noise level".to_string(),
523 impact: F::from(-noise_penalty).unwrap(),
524 confidence: F::from(0.8).unwrap(),
525 mitigations: vec![
526 "Apply data smoothing".to_string(),
527 "Use robust methods".to_string(),
528 ],
529 },
530 ],
531 })
532 }
533
534 pub fn add_error_record(&mut self, record: ErrorRecord<F>) -> InterpolateResult<()> {
536 self.error_history.push_back(record);
537
538 if self.error_history.len() > 1000 {
540 self.error_history.pop_front();
541 }
542
543 Ok(())
544 }
545
546 pub fn update_model(&mut self) -> InterpolateResult<()> {
548 if self.error_history.len() < 10 {
549 return Ok(()); }
551
552 let recent_records: Vec<_> = self.error_history.iter().rev().take(50).collect();
554 let mut total_error = F::zero();
555 let mut count = 0;
556
557 for record in recent_records {
558 let relative_error =
559 (record.predicted_error - record.actual_error).abs() / record.actual_error;
560 total_error += relative_error;
561 count += 1;
562 }
563
564 if count > 0 {
565 let avg_relative_error = total_error / F::from(count).unwrap();
566 self.model_accuracy = (F::one() - avg_relative_error).max(F::from(0.1).unwrap());
567 }
568
569 Ok(())
570 }
571
572 fn get_base_accuracy(&self, method: InterpolationMethodType) -> f64 {
574 match method {
575 InterpolationMethodType::Linear => 0.7,
576 InterpolationMethodType::CubicSpline => 0.9,
577 InterpolationMethodType::BSpline => 0.92,
578 InterpolationMethodType::RadialBasisFunction => 0.95,
579 InterpolationMethodType::Kriging => 0.98,
580 InterpolationMethodType::Polynomial => 0.85,
581 InterpolationMethodType::PchipInterpolation => 0.88,
582 InterpolationMethodType::AkimaSpline => 0.87,
583 InterpolationMethodType::ThinPlateSpline => 0.93,
584 InterpolationMethodType::NaturalNeighbor => 0.86,
585 InterpolationMethodType::ShepardsMethod => 0.75,
586 InterpolationMethodType::QuantumInspired => 0.99,
587 }
588 }
589
590 pub fn get_model_accuracy(&self) -> F {
592 self.model_accuracy
593 }
594
595 pub fn get_prediction_statistics(&self) -> HashMap<String, f64> {
597 let mut stats = HashMap::new();
598
599 if !self.error_history.is_empty() {
600 let mut total_abs_error = F::zero();
601 let mut total_rel_error = F::zero();
602 let count = self.error_history.len();
603
604 for record in &self.error_history {
605 let abs_error = (record.predicted_error - record.actual_error).abs();
606 let rel_error = abs_error / record.actual_error;
607
608 total_abs_error += abs_error;
609 total_rel_error += rel_error;
610 }
611
612 stats.insert(
613 "mean_absolute_error".to_string(),
614 (total_abs_error / F::from(count).unwrap())
615 .to_f64()
616 .unwrap_or(0.0),
617 );
618 stats.insert(
619 "mean_relative_error".to_string(),
620 (total_rel_error / F::from(count).unwrap())
621 .to_f64()
622 .unwrap_or(0.0),
623 );
624 stats.insert(
625 "model_accuracy".to_string(),
626 self.model_accuracy.to_f64().unwrap_or(0.0),
627 );
628 stats.insert("sample_count".to_string(), count as f64);
629 }
630
631 stats
632 }
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: 1e-6 error ceiling, 0.95 confidence, no absolute target.
    #[test]
    fn test_accuracy_targets_default() {
        let AccuracyTargets {
            target_absolute_error,
            max_acceptable_error,
            confidence_level,
            ..
        } = AccuracyTargets::<f64>::default();

        assert!(target_absolute_error.is_none());
        assert_eq!(max_acceptable_error, 1e-6);
        assert_eq!(confidence_level, 0.95);
    }

    /// A fresh model starts at 0.8 accuracy with no recorded errors.
    #[test]
    fn test_error_prediction_model_creation() {
        let model = ErrorPredictionModel::<f64>::new().unwrap();

        assert!(model.error_history.is_empty());
        assert_eq!(model.get_model_accuracy(), 0.8);
    }

    /// A fresh engine uses the balanced strategy and has no history.
    #[test]
    fn test_accuracy_optimization_engine_creation() {
        let engine = AccuracyOptimizationEngine::<f64>::new().unwrap();

        let is_balanced = matches!(
            engine.strategy,
            AccuracyOptimizationStrategy::BalancedAccuracy
        );
        assert!(is_balanced);
        assert!(engine.get_optimization_history().is_empty());
    }
}