1use sklears_core::types::Float;
15use std::any::Any;
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19pub trait OptimizerPlugin: Send + Sync {
25 fn name(&self) -> &str;
27
28 fn version(&self) -> &str;
30
31 fn description(&self) -> &str;
33
34 fn initialize(&mut self, config: &PluginConfig) -> Result<(), PluginError>;
36
37 fn suggest(
39 &mut self,
40 history: &OptimizationHistory,
41 constraints: &ParameterConstraints,
42 ) -> Result<HashMap<String, Float>, PluginError>;
43
44 fn observe(
46 &mut self,
47 parameters: &HashMap<String, Float>,
48 objective_value: Float,
49 metadata: Option<&HashMap<String, String>>,
50 ) -> Result<(), PluginError>;
51
52 fn should_stop(&self, history: &OptimizationHistory) -> bool;
54
55 fn get_statistics(&self) -> Result<HashMap<String, Float>, PluginError>;
57
58 fn shutdown(&mut self) -> Result<(), PluginError>;
60
61 fn as_any(&self) -> &dyn Any;
63
64 fn as_any_mut(&mut self) -> &mut dyn Any;
66}
67
/// Settings handed to a plugin when it is initialized.
#[derive(Debug, Clone)]
pub struct PluginConfig {
    /// Iteration budget for the optimization (default: 100).
    pub max_iterations: usize,
    /// Optional random seed; `None` (the default) leaves seeding up to the plugin.
    pub random_seed: Option<u64>,
    /// Requests parallel execution; how this is honored is plugin-specific
    /// (default: `false`).
    pub parallel: bool,
    /// Free-form, plugin-specific string parameters (default: empty).
    pub custom_params: HashMap<String, String>,
}

impl Default for PluginConfig {
    /// 100 iterations, no seed, sequential execution, no custom parameters.
    fn default() -> Self {
        let custom_params = HashMap::default();
        PluginConfig {
            custom_params,
            parallel: false,
            random_seed: None,
            max_iterations: 100,
        }
    }
}
87
88#[derive(Debug, Clone)]
90pub struct OptimizationHistory {
91 pub evaluations: Vec<Evaluation>,
92 pub best_value: Float,
93 pub best_parameters: HashMap<String, Float>,
94 pub n_evaluations: usize,
95}
96
97impl OptimizationHistory {
98 pub fn new() -> Self {
99 Self {
100 evaluations: Vec::new(),
101 best_value: f64::NEG_INFINITY,
102 best_parameters: HashMap::new(),
103 n_evaluations: 0,
104 }
105 }
106
107 pub fn add_evaluation(&mut self, params: HashMap<String, Float>, value: Float) {
108 self.evaluations.push(Evaluation {
109 parameters: params.clone(),
110 objective_value: value,
111 iteration: self.n_evaluations,
112 });
113
114 if value > self.best_value {
115 self.best_value = value;
116 self.best_parameters = params;
117 }
118
119 self.n_evaluations += 1;
120 }
121}
122
123impl Default for OptimizationHistory {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct Evaluation {
132 pub parameters: HashMap<String, Float>,
133 pub objective_value: Float,
134 pub iteration: usize,
135}
136
137#[derive(Debug, Clone, Default)]
139pub struct ParameterConstraints {
140 pub bounds: HashMap<String, (Float, Float)>,
141 pub integer_params: Vec<String>,
142 pub categorical_params: HashMap<String, Vec<String>>,
143}
144
/// Errors reported by optimizer plugins and the plugin registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginError {
    /// Creating or initializing a plugin failed; also used for unknown plugin names.
    InitializationFailed(String),
    /// The plugin could not produce a suggestion.
    SuggestionFailed(String),
    /// The plugin could not record an observation.
    ObservationFailed(String),
    /// The supplied configuration was rejected.
    InvalidConfiguration(String),
    /// The plugin was used before being initialized.
    NotInitialized,
    /// A second initialization was attempted; the registry also reports duplicate
    /// factory names with this variant.
    AlreadyInitialized,
    /// Unexpected internal failure, such as a poisoned lock.
    InternalError(String),
}

impl std::fmt::Display for PluginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PluginError::InitializationFailed(msg) => write!(f, "Initialization failed: {}", msg),
            PluginError::SuggestionFailed(msg) => write!(f, "Suggestion failed: {}", msg),
            PluginError::ObservationFailed(msg) => write!(f, "Observation failed: {}", msg),
            PluginError::InvalidConfiguration(msg) => write!(f, "Invalid configuration: {}", msg),
            PluginError::NotInitialized => f.write_str("Plugin not initialized"),
            PluginError::AlreadyInitialized => f.write_str("Plugin already initialized"),
            PluginError::InternalError(msg) => write!(f, "Internal error: {}", msg),
        }
    }
}

impl std::error::Error for PluginError {}
172
173pub struct PluginRegistry {
179 plugins: Arc<RwLock<HashMap<String, Box<dyn OptimizerPlugin>>>>,
180 factories: Arc<RwLock<HashMap<String, PluginFactory>>>,
181}
182
183impl PluginRegistry {
184 pub fn new() -> Self {
185 Self {
186 plugins: Arc::new(RwLock::new(HashMap::new())),
187 factories: Arc::new(RwLock::new(HashMap::new())),
188 }
189 }
190
191 pub fn register_factory(
193 &self,
194 name: String,
195 factory: PluginFactory,
196 ) -> Result<(), PluginError> {
197 let mut factories = self
198 .factories
199 .write()
200 .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
201
202 if factories.contains_key(&name) {
203 return Err(PluginError::AlreadyInitialized);
204 }
205
206 factories.insert(name, factory);
207 Ok(())
208 }
209
210 pub fn create_plugin(&self, name: &str, config: &PluginConfig) -> Result<(), PluginError> {
212 let factories = self
213 .factories
214 .read()
215 .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
216
217 let factory = factories.get(name).ok_or_else(|| {
218 PluginError::InitializationFailed(format!("Plugin '{}' not found", name))
219 })?;
220
221 let mut plugin = (factory.create)()?;
222 plugin.initialize(config)?;
223
224 let mut plugins = self
225 .plugins
226 .write()
227 .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
228
229 plugins.insert(name.to_string(), plugin);
230 Ok(())
231 }
232
233 pub fn get_plugin(&self, name: &str) -> Result<Box<dyn OptimizerPlugin>, PluginError> {
235 let plugins = self
236 .plugins
237 .read()
238 .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
239
240 plugins
241 .get(name)
242 .ok_or_else(|| {
243 PluginError::InitializationFailed(format!("Plugin '{}' not found", name))
244 })
245 .map(|_plugin| {
246 Err(PluginError::InternalError(
249 "Cannot borrow plugin".to_string(),
250 ))
251 })?
252 }
253
254 pub fn list_plugins(&self) -> Result<Vec<String>, PluginError> {
256 let factories = self
257 .factories
258 .read()
259 .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
260
261 Ok(factories.keys().cloned().collect())
262 }
263
264 pub fn unregister_plugin(&self, name: &str) -> Result<(), PluginError> {
266 let mut plugins = self
267 .plugins
268 .write()
269 .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
270
271 if let Some(mut plugin) = plugins.remove(name) {
272 plugin.shutdown()?;
273 }
274
275 Ok(())
276 }
277}
278
279impl Default for PluginRegistry {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285pub struct PluginFactory {
287 pub name: String,
288 pub version: String,
289 pub description: String,
290 pub create: fn() -> Result<Box<dyn OptimizerPlugin>, PluginError>,
291}
292
293pub trait OptimizationHook: Send + Sync {
299 fn on_optimization_start(
301 &mut self,
302 config: &PluginConfig,
303 constraints: &ParameterConstraints,
304 ) -> Result<(), HookError>;
305
306 fn on_iteration_start(
308 &mut self,
309 iteration: usize,
310 history: &OptimizationHistory,
311 ) -> Result<(), HookError>;
312
313 fn on_evaluation(
315 &mut self,
316 parameters: &HashMap<String, Float>,
317 objective_value: Float,
318 iteration: usize,
319 ) -> Result<(), HookError>;
320
321 fn on_iteration_end(
323 &mut self,
324 iteration: usize,
325 history: &OptimizationHistory,
326 ) -> Result<(), HookError>;
327
328 fn on_optimization_end(
330 &mut self,
331 history: &OptimizationHistory,
332 reason: StopReason,
333 ) -> Result<(), HookError>;
334
335 fn on_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError>;
337}
338
/// Why an optimization run ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StopReason {
    /// The iteration budget was used up.
    MaxIterationsReached,
    /// A convergence criterion was met.
    ConvergenceReached,
    /// The time budget was used up.
    TimeoutReached,
    /// The user interrupted the run.
    UserInterrupted,
    /// The optimizer plugin requested the stop.
    PluginDecision,
    /// The run ended because of an error, described by the message.
    Error(String),
}

/// Errors returned by [`OptimizationHook`] callbacks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookError {
    /// The hook failed while running.
    ExecutionFailed(String),
    /// The hook was invoked in a state it cannot handle.
    InvalidState(String),
}

impl std::fmt::Display for HookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HookError::ExecutionFailed(msg) => write!(f, "Hook execution failed: {}", msg),
            HookError::InvalidState(msg) => write!(f, "Invalid state: {}", msg),
        }
    }
}

impl std::error::Error for HookError {}
367
368pub struct HookManager {
370 hooks: Vec<Box<dyn OptimizationHook>>,
371}
372
373impl HookManager {
374 pub fn new() -> Self {
375 Self { hooks: Vec::new() }
376 }
377
378 pub fn add_hook(&mut self, hook: Box<dyn OptimizationHook>) {
379 self.hooks.push(hook);
380 }
381
382 pub fn trigger_optimization_start(
383 &mut self,
384 config: &PluginConfig,
385 constraints: &ParameterConstraints,
386 ) -> Result<(), HookError> {
387 for hook in &mut self.hooks {
388 hook.on_optimization_start(config, constraints)?;
389 }
390 Ok(())
391 }
392
393 pub fn trigger_iteration_start(
394 &mut self,
395 iteration: usize,
396 history: &OptimizationHistory,
397 ) -> Result<(), HookError> {
398 for hook in &mut self.hooks {
399 hook.on_iteration_start(iteration, history)?;
400 }
401 Ok(())
402 }
403
404 pub fn trigger_evaluation(
405 &mut self,
406 parameters: &HashMap<String, Float>,
407 objective_value: Float,
408 iteration: usize,
409 ) -> Result<(), HookError> {
410 for hook in &mut self.hooks {
411 hook.on_evaluation(parameters, objective_value, iteration)?;
412 }
413 Ok(())
414 }
415
416 pub fn trigger_iteration_end(
417 &mut self,
418 iteration: usize,
419 history: &OptimizationHistory,
420 ) -> Result<(), HookError> {
421 for hook in &mut self.hooks {
422 hook.on_iteration_end(iteration, history)?;
423 }
424 Ok(())
425 }
426
427 pub fn trigger_optimization_end(
428 &mut self,
429 history: &OptimizationHistory,
430 reason: StopReason,
431 ) -> Result<(), HookError> {
432 for hook in &mut self.hooks {
433 hook.on_optimization_end(history, reason.clone())?;
434 }
435 Ok(())
436 }
437
438 pub fn trigger_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError> {
439 for hook in &mut self.hooks {
440 hook.on_error(error)?;
441 }
442 Ok(())
443 }
444}
445
446impl Default for HookManager {
447 fn default() -> Self {
448 Self::new()
449 }
450}
451
452pub trait OptimizationMiddleware: Send + Sync {
458 fn process_suggestion(
460 &self,
461 parameters: &mut HashMap<String, Float>,
462 history: &OptimizationHistory,
463 ) -> Result<(), MiddlewareError>;
464
465 fn process_observation(
467 &self,
468 parameters: &HashMap<String, Float>,
469 objective_value: &mut Float,
470 history: &OptimizationHistory,
471 ) -> Result<(), MiddlewareError>;
472
473 fn name(&self) -> &str;
475}
476
/// Errors returned by [`OptimizationMiddleware`] stages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareError {
    /// The stage failed while transforming its input.
    ProcessingFailed(String),
    /// The stage rejected its input as invalid.
    ValidationFailed(String),
}

impl std::fmt::Display for MiddlewareError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MiddlewareError::ProcessingFailed(msg) => write!(f, "Processing failed: {}", msg),
            MiddlewareError::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
        }
    }
}

impl std::error::Error for MiddlewareError {}
494
495pub struct MiddlewarePipeline {
497 middleware: Vec<Box<dyn OptimizationMiddleware>>,
498}
499
500impl MiddlewarePipeline {
501 pub fn new() -> Self {
502 Self {
503 middleware: Vec::new(),
504 }
505 }
506
507 pub fn add(&mut self, middleware: Box<dyn OptimizationMiddleware>) {
508 self.middleware.push(middleware);
509 }
510
511 pub fn process_suggestion(
512 &self,
513 parameters: &mut HashMap<String, Float>,
514 history: &OptimizationHistory,
515 ) -> Result<(), MiddlewareError> {
516 for m in &self.middleware {
517 m.process_suggestion(parameters, history)?;
518 }
519 Ok(())
520 }
521
522 pub fn process_observation(
523 &self,
524 parameters: &HashMap<String, Float>,
525 objective_value: &mut Float,
526 history: &OptimizationHistory,
527 ) -> Result<(), MiddlewareError> {
528 for m in &self.middleware {
529 m.process_observation(parameters, objective_value, history)?;
530 }
531 Ok(())
532 }
533}
534
535impl Default for MiddlewarePipeline {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541pub trait CustomMetric: Send + Sync {
547 fn name(&self) -> &str;
549
550 fn compute(
552 &self,
553 parameters: &HashMap<String, Float>,
554 objective_value: Float,
555 history: &OptimizationHistory,
556 ) -> Result<Float, MetricError>;
557
558 fn higher_is_better(&self) -> bool;
560}
561
/// Errors produced by custom metrics and the metric registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MetricError {
    /// The metric could not be computed.
    ComputationFailed(String),
    /// Bad input, e.g. an unknown or duplicate metric name.
    InvalidInput(String),
}

impl std::fmt::Display for MetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MetricError::ComputationFailed(msg) => write!(f, "Computation failed: {}", msg),
            MetricError::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
        }
    }
}

impl std::error::Error for MetricError {}
579
580pub struct MetricRegistry {
582 metrics: HashMap<String, Box<dyn CustomMetric>>,
583}
584
585impl MetricRegistry {
586 pub fn new() -> Self {
587 Self {
588 metrics: HashMap::new(),
589 }
590 }
591
592 pub fn register(&mut self, metric: Box<dyn CustomMetric>) -> Result<(), MetricError> {
593 let name = metric.name().to_string();
594 if self.metrics.contains_key(&name) {
595 return Err(MetricError::InvalidInput(format!(
596 "Metric '{}' already registered",
597 name
598 )));
599 }
600 self.metrics.insert(name, metric);
601 Ok(())
602 }
603
604 pub fn compute(
605 &self,
606 metric_name: &str,
607 parameters: &HashMap<String, Float>,
608 objective_value: Float,
609 history: &OptimizationHistory,
610 ) -> Result<Float, MetricError> {
611 let metric = self.metrics.get(metric_name).ok_or_else(|| {
612 MetricError::InvalidInput(format!("Metric '{}' not found", metric_name))
613 })?;
614
615 metric.compute(parameters, objective_value, history)
616 }
617
618 pub fn list_metrics(&self) -> Vec<String> {
619 self.metrics.keys().cloned().collect()
620 }
621}
622
623impl Default for MetricRegistry {
624 fn default() -> Self {
625 Self::new()
626 }
627}
628
629pub struct LoggingHook {
635 log_interval: usize,
636}
637
638impl LoggingHook {
639 pub fn new(log_interval: usize) -> Self {
640 Self { log_interval }
641 }
642}
643
644impl OptimizationHook for LoggingHook {
645 fn on_optimization_start(
646 &mut self,
647 _config: &PluginConfig,
648 _constraints: &ParameterConstraints,
649 ) -> Result<(), HookError> {
650 println!("Optimization started");
651 Ok(())
652 }
653
654 fn on_iteration_start(
655 &mut self,
656 iteration: usize,
657 _history: &OptimizationHistory,
658 ) -> Result<(), HookError> {
659 if iteration % self.log_interval == 0 {
660 println!("Starting iteration {}", iteration);
661 }
662 Ok(())
663 }
664
665 fn on_evaluation(
666 &mut self,
667 _parameters: &HashMap<String, Float>,
668 _objective_value: Float,
669 _iteration: usize,
670 ) -> Result<(), HookError> {
671 Ok(())
672 }
673
674 fn on_iteration_end(
675 &mut self,
676 iteration: usize,
677 history: &OptimizationHistory,
678 ) -> Result<(), HookError> {
679 if iteration % self.log_interval == 0 {
680 println!(
681 "Iteration {} complete. Best so far: {}",
682 iteration, history.best_value
683 );
684 }
685 Ok(())
686 }
687
688 fn on_optimization_end(
689 &mut self,
690 history: &OptimizationHistory,
691 reason: StopReason,
692 ) -> Result<(), HookError> {
693 println!("Optimization ended. Reason: {:?}", reason);
694 println!("Best value: {}", history.best_value);
695 println!("Total evaluations: {}", history.n_evaluations);
696 Ok(())
697 }
698
699 fn on_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError> {
700 eprintln!("Error during optimization: {}", error);
701 Ok(())
702 }
703}
704
705pub struct NormalizationMiddleware {
707 name: String,
708}
709
710impl NormalizationMiddleware {
711 pub fn new() -> Self {
712 Self {
713 name: "normalization".to_string(),
714 }
715 }
716}
717
718impl Default for NormalizationMiddleware {
719 fn default() -> Self {
720 Self::new()
721 }
722}
723
724impl OptimizationMiddleware for NormalizationMiddleware {
725 fn process_suggestion(
726 &self,
727 parameters: &mut HashMap<String, Float>,
728 _history: &OptimizationHistory,
729 ) -> Result<(), MiddlewareError> {
730 for (_, value) in parameters.iter_mut() {
732 if value.is_nan() || value.is_infinite() {
733 *value = 0.0; }
735 }
736 Ok(())
737 }
738
739 fn process_observation(
740 &self,
741 _parameters: &HashMap<String, Float>,
742 objective_value: &mut Float,
743 _history: &OptimizationHistory,
744 ) -> Result<(), MiddlewareError> {
745 if objective_value.is_nan() || objective_value.is_infinite() {
747 *objective_value = f64::NEG_INFINITY;
748 }
749 Ok(())
750 }
751
752 fn name(&self) -> &str {
753 &self.name
754 }
755}
756
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal metric used to exercise `MetricRegistry`: returns twice the objective.
    struct DoublingMetric;

    impl CustomMetric for DoublingMetric {
        fn name(&self) -> &str {
            "doubling"
        }

        fn compute(
            &self,
            _parameters: &HashMap<String, Float>,
            objective_value: Float,
            _history: &OptimizationHistory,
        ) -> Result<Float, MetricError> {
            Ok(objective_value * 2.0)
        }

        fn higher_is_better(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_plugin_config() {
        let config = PluginConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(!config.parallel);
    }

    #[test]
    fn test_optimization_history() {
        let mut history = OptimizationHistory::new();
        assert_eq!(history.n_evaluations, 0);

        let mut params = HashMap::new();
        params.insert("lr".to_string(), 0.01);
        history.add_evaluation(params, 0.95);

        assert_eq!(history.n_evaluations, 1);
        assert_eq!(history.best_value, 0.95);
    }

    #[test]
    fn test_history_keeps_first_strict_best() {
        let mut history = OptimizationHistory::new();

        let mut good = HashMap::new();
        good.insert("lr".to_string(), 0.1);
        let mut worse = HashMap::new();
        worse.insert("lr".to_string(), 0.5);
        let mut tie = HashMap::new();
        tie.insert("lr".to_string(), 0.9);

        history.add_evaluation(good.clone(), 0.9);
        history.add_evaluation(worse, 0.3);
        history.add_evaluation(tie, 0.9);

        assert_eq!(history.n_evaluations, 3);
        assert_eq!(history.evaluations.len(), 3);
        assert_eq!(history.best_value, 0.9);
        assert_eq!(history.best_parameters, good);
        assert_eq!(history.evaluations[1].iteration, 1);
    }

    #[test]
    fn test_plugin_registry() {
        let registry = PluginRegistry::new();
        assert!(registry.list_plugins().is_ok());
    }

    #[test]
    fn test_plugin_registry_missing_plugin() {
        let registry = PluginRegistry::new();
        let config = PluginConfig::default();

        assert!(registry.list_plugins().unwrap().is_empty());
        assert!(matches!(
            registry.create_plugin("missing", &config),
            Err(PluginError::InitializationFailed(_))
        ));
        assert!(matches!(
            registry.get_plugin("missing"),
            Err(PluginError::InitializationFailed(_))
        ));
        assert!(registry.unregister_plugin("missing").is_ok());
    }

    #[test]
    fn test_hook_manager() {
        let mut manager = HookManager::new();
        let hook = Box::new(LoggingHook::new(10));
        manager.add_hook(hook);

        let config = PluginConfig::default();
        let constraints = ParameterConstraints::default();
        assert!(manager
            .trigger_optimization_start(&config, &constraints)
            .is_ok());
    }

    #[test]
    fn test_middleware_pipeline() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(Box::new(NormalizationMiddleware::new()));

        let mut params = HashMap::new();
        params.insert("x".to_string(), Float::NAN);

        let history = OptimizationHistory::new();
        assert!(pipeline.process_suggestion(&mut params, &history).is_ok());
        assert_eq!(params.get("x"), Some(&0.0));
    }

    #[test]
    fn test_normalization_observation() {
        let middleware = NormalizationMiddleware::new();
        let history = OptimizationHistory::new();
        let params: HashMap<String, Float> = HashMap::new();

        let mut nan_objective: Float = Float::NAN;
        assert!(middleware
            .process_observation(&params, &mut nan_objective, &history)
            .is_ok());
        assert_eq!(nan_objective, Float::NEG_INFINITY);

        let mut inf_objective: Float = Float::INFINITY;
        assert!(middleware
            .process_observation(&params, &mut inf_objective, &history)
            .is_ok());
        assert_eq!(inf_objective, Float::NEG_INFINITY);

        let mut finite_objective: Float = 1.5;
        assert!(middleware
            .process_observation(&params, &mut finite_objective, &history)
            .is_ok());
        assert_eq!(finite_objective, 1.5);

        assert_eq!(middleware.name(), "normalization");
    }

    #[test]
    fn test_metric_registry() {
        let registry = MetricRegistry::new();
        assert!(registry.list_metrics().is_empty());
    }

    #[test]
    fn test_metric_registry_register_and_compute() {
        let mut registry = MetricRegistry::new();
        assert!(registry.register(Box::new(DoublingMetric)).is_ok());
        assert!(matches!(
            registry.register(Box::new(DoublingMetric)),
            Err(MetricError::InvalidInput(_))
        ));

        let history = OptimizationHistory::new();
        let params: HashMap<String, Float> = HashMap::new();
        let value = registry
            .compute("doubling", &params, 2.0, &history)
            .unwrap();
        assert_eq!(value, 4.0);
        assert!(matches!(
            registry.compute("unknown", &params, 2.0, &history),
            Err(MetricError::InvalidInput(_))
        ));
        assert_eq!(registry.list_metrics(), vec!["doubling".to_string()]);
    }
}