sklears_model_selection/
optimizer_plugins.rs

1//! Plugin Architecture for Custom Optimizers
2//!
3//! This module provides a flexible plugin system for implementing custom hyperparameter
4//! optimization algorithms. It includes:
5//! - Trait-based plugin interface
6//! - Plugin registry for dynamic loading
7//! - Hook system for optimization callbacks
8//! - Middleware support for optimization pipelines
9//! - Custom metric registration
10//!
11//! This enables users to extend the optimization framework with their own algorithms
12//! without modifying the core library.
13
14use sklears_core::types::Float;
15use std::any::Any;
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19// ============================================================================
20// Core Plugin Traits
21// ============================================================================
22
23/// Core trait for optimization plugins
24pub trait OptimizerPlugin: Send + Sync {
25    /// Plugin name (must be unique)
26    fn name(&self) -> &str;
27
28    /// Plugin version
29    fn version(&self) -> &str;
30
31    /// Plugin description
32    fn description(&self) -> &str;
33
34    /// Initialize the plugin
35    fn initialize(&mut self, config: &PluginConfig) -> Result<(), PluginError>;
36
37    /// Suggest next hyperparameter configuration to evaluate
38    fn suggest(
39        &mut self,
40        history: &OptimizationHistory,
41        constraints: &ParameterConstraints,
42    ) -> Result<HashMap<String, Float>, PluginError>;
43
44    /// Update plugin state with new observation
45    fn observe(
46        &mut self,
47        parameters: &HashMap<String, Float>,
48        objective_value: Float,
49        metadata: Option<&HashMap<String, String>>,
50    ) -> Result<(), PluginError>;
51
52    /// Check if optimization should stop
53    fn should_stop(&self, history: &OptimizationHistory) -> bool;
54
55    /// Get plugin-specific statistics
56    fn get_statistics(&self) -> Result<HashMap<String, Float>, PluginError>;
57
58    /// Clean up resources
59    fn shutdown(&mut self) -> Result<(), PluginError>;
60
61    /// Downcast to concrete type (for advanced usage)
62    fn as_any(&self) -> &dyn Any;
63
64    /// Mutable downcast
65    fn as_any_mut(&mut self) -> &mut dyn Any;
66}
67
/// Settings passed to a plugin when it is initialized.
#[derive(Clone, Debug)]
pub struct PluginConfig {
    /// Maximum number of iterations (default: 100).
    pub max_iterations: usize,
    /// Optional random seed (default: `None`).
    pub random_seed: Option<u64>,
    /// Parallel flag (default: `false`).
    pub parallel: bool,
    /// Free-form string key/value settings (default: empty).
    pub custom_params: HashMap<String, String>,
}

impl Default for PluginConfig {
    /// Builds the default configuration: 100 iterations, no seed, not
    /// parallel, and no custom parameters.
    fn default() -> Self {
        let custom_params = HashMap::new();
        Self {
            custom_params,
            parallel: false,
            random_seed: None,
            max_iterations: 100,
        }
    }
}
87
88/// Optimization history
89#[derive(Debug, Clone)]
90pub struct OptimizationHistory {
91    pub evaluations: Vec<Evaluation>,
92    pub best_value: Float,
93    pub best_parameters: HashMap<String, Float>,
94    pub n_evaluations: usize,
95}
96
97impl OptimizationHistory {
98    pub fn new() -> Self {
99        Self {
100            evaluations: Vec::new(),
101            best_value: f64::NEG_INFINITY,
102            best_parameters: HashMap::new(),
103            n_evaluations: 0,
104        }
105    }
106
107    pub fn add_evaluation(&mut self, params: HashMap<String, Float>, value: Float) {
108        self.evaluations.push(Evaluation {
109            parameters: params.clone(),
110            objective_value: value,
111            iteration: self.n_evaluations,
112        });
113
114        if value > self.best_value {
115            self.best_value = value;
116            self.best_parameters = params;
117        }
118
119        self.n_evaluations += 1;
120    }
121}
122
123impl Default for OptimizationHistory {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129/// Single evaluation record
130#[derive(Debug, Clone)]
131pub struct Evaluation {
132    pub parameters: HashMap<String, Float>,
133    pub objective_value: Float,
134    pub iteration: usize,
135}
136
137/// Parameter constraints for optimization
138#[derive(Debug, Clone, Default)]
139pub struct ParameterConstraints {
140    pub bounds: HashMap<String, (Float, Float)>,
141    pub integer_params: Vec<String>,
142    pub categorical_params: HashMap<String, Vec<String>>,
143}
144
/// Errors reported by plugins and by the plugin registry.
#[derive(Clone, Debug)]
pub enum PluginError {
    /// Initialization failed; the payload is the detail message.
    InitializationFailed(String),
    /// A suggestion could not be produced; the payload is the detail message.
    SuggestionFailed(String),
    /// An observation could not be recorded; the payload is the detail message.
    ObservationFailed(String),
    /// The configuration was rejected; the payload is the detail message.
    InvalidConfiguration(String),
    /// The plugin has not been initialized.
    NotInitialized,
    /// The plugin has already been initialized.
    AlreadyInitialized,
    /// Any other internal failure; the payload is the detail message.
    InternalError(String),
}

impl std::fmt::Display for PluginError {
    /// Writes `"<label>: <detail>"` for variants carrying a message and a
    /// fixed sentence for the two payload-free variants.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::NotInitialized => return f.write_str("Plugin not initialized"),
            Self::AlreadyInitialized => return f.write_str("Plugin already initialized"),
            Self::InitializationFailed(detail) => ("Initialization failed", detail),
            Self::SuggestionFailed(detail) => ("Suggestion failed", detail),
            Self::ObservationFailed(detail) => ("Observation failed", detail),
            Self::InvalidConfiguration(detail) => ("Invalid configuration", detail),
            Self::InternalError(detail) => ("Internal error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for PluginError {}
172
173// ============================================================================
174// Plugin Registry
175// ============================================================================
176
177/// Global plugin registry
178pub struct PluginRegistry {
179    plugins: Arc<RwLock<HashMap<String, Box<dyn OptimizerPlugin>>>>,
180    factories: Arc<RwLock<HashMap<String, PluginFactory>>>,
181}
182
183impl PluginRegistry {
184    pub fn new() -> Self {
185        Self {
186            plugins: Arc::new(RwLock::new(HashMap::new())),
187            factories: Arc::new(RwLock::new(HashMap::new())),
188        }
189    }
190
191    /// Register a plugin factory
192    pub fn register_factory(
193        &self,
194        name: String,
195        factory: PluginFactory,
196    ) -> Result<(), PluginError> {
197        let mut factories = self
198            .factories
199            .write()
200            .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
201
202        if factories.contains_key(&name) {
203            return Err(PluginError::AlreadyInitialized);
204        }
205
206        factories.insert(name, factory);
207        Ok(())
208    }
209
210    /// Create and register a plugin instance
211    pub fn create_plugin(&self, name: &str, config: &PluginConfig) -> Result<(), PluginError> {
212        let factories = self
213            .factories
214            .read()
215            .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
216
217        let factory = factories.get(name).ok_or_else(|| {
218            PluginError::InitializationFailed(format!("Plugin '{}' not found", name))
219        })?;
220
221        let mut plugin = (factory.create)()?;
222        plugin.initialize(config)?;
223
224        let mut plugins = self
225            .plugins
226            .write()
227            .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
228
229        plugins.insert(name.to_string(), plugin);
230        Ok(())
231    }
232
233    /// Get a plugin by name
234    pub fn get_plugin(&self, name: &str) -> Result<Box<dyn OptimizerPlugin>, PluginError> {
235        let plugins = self
236            .plugins
237            .read()
238            .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
239
240        plugins
241            .get(name)
242            .ok_or_else(|| {
243                PluginError::InitializationFailed(format!("Plugin '{}' not found", name))
244            })
245            .map(|_plugin| {
246                // This is a placeholder - in reality, we'd need to clone or use Arc
247                // For now, return an error
248                Err(PluginError::InternalError(
249                    "Cannot borrow plugin".to_string(),
250                ))
251            })?
252    }
253
254    /// List all registered plugin names
255    pub fn list_plugins(&self) -> Result<Vec<String>, PluginError> {
256        let factories = self
257            .factories
258            .read()
259            .map_err(|e| PluginError::InternalError(format!("Failed to lock registry: {}", e)))?;
260
261        Ok(factories.keys().cloned().collect())
262    }
263
264    /// Unregister a plugin
265    pub fn unregister_plugin(&self, name: &str) -> Result<(), PluginError> {
266        let mut plugins = self
267            .plugins
268            .write()
269            .map_err(|e| PluginError::InternalError(format!("Failed to lock plugins: {}", e)))?;
270
271        if let Some(mut plugin) = plugins.remove(name) {
272            plugin.shutdown()?;
273        }
274
275        Ok(())
276    }
277}
278
279impl Default for PluginRegistry {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285/// Plugin factory for creating plugin instances
286pub struct PluginFactory {
287    pub name: String,
288    pub version: String,
289    pub description: String,
290    pub create: fn() -> Result<Box<dyn OptimizerPlugin>, PluginError>,
291}
292
293// ============================================================================
294// Hook System
295// ============================================================================
296
297/// Optimization hook trait for callbacks
298pub trait OptimizationHook: Send + Sync {
299    /// Called before optimization starts
300    fn on_optimization_start(
301        &mut self,
302        config: &PluginConfig,
303        constraints: &ParameterConstraints,
304    ) -> Result<(), HookError>;
305
306    /// Called before each iteration
307    fn on_iteration_start(
308        &mut self,
309        iteration: usize,
310        history: &OptimizationHistory,
311    ) -> Result<(), HookError>;
312
313    /// Called after each evaluation
314    fn on_evaluation(
315        &mut self,
316        parameters: &HashMap<String, Float>,
317        objective_value: Float,
318        iteration: usize,
319    ) -> Result<(), HookError>;
320
321    /// Called after each iteration
322    fn on_iteration_end(
323        &mut self,
324        iteration: usize,
325        history: &OptimizationHistory,
326    ) -> Result<(), HookError>;
327
328    /// Called when optimization completes
329    fn on_optimization_end(
330        &mut self,
331        history: &OptimizationHistory,
332        reason: StopReason,
333    ) -> Result<(), HookError>;
334
335    /// Called on optimization error
336    fn on_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError>;
337}
338
/// Why an optimization run stopped.
#[derive(Clone, Debug)]
pub enum StopReason {
    /// The maximum number of iterations was reached.
    MaxIterationsReached,
    /// Convergence was reached.
    ConvergenceReached,
    /// A timeout was reached.
    TimeoutReached,
    /// The user interrupted the run.
    UserInterrupted,
    /// The plugin decided to stop.
    PluginDecision,
    /// The run stopped because of an error; the payload is its message.
    Error(String),
}
349
/// Errors reported by optimization hooks.
#[derive(Clone, Debug)]
pub enum HookError {
    /// The hook failed while executing; the payload is the detail message.
    ExecutionFailed(String),
    /// The hook was in an invalid state; the payload is the detail message.
    InvalidState(String),
}

impl std::fmt::Display for HookError {
    /// Writes `"<label>: <detail>"` for the variant at hand.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::ExecutionFailed(detail) => ("Hook execution failed", detail),
            Self::InvalidState(detail) => ("Invalid state", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for HookError {}
367
368/// Hook manager for managing multiple hooks
369pub struct HookManager {
370    hooks: Vec<Box<dyn OptimizationHook>>,
371}
372
373impl HookManager {
374    pub fn new() -> Self {
375        Self { hooks: Vec::new() }
376    }
377
378    pub fn add_hook(&mut self, hook: Box<dyn OptimizationHook>) {
379        self.hooks.push(hook);
380    }
381
382    pub fn trigger_optimization_start(
383        &mut self,
384        config: &PluginConfig,
385        constraints: &ParameterConstraints,
386    ) -> Result<(), HookError> {
387        for hook in &mut self.hooks {
388            hook.on_optimization_start(config, constraints)?;
389        }
390        Ok(())
391    }
392
393    pub fn trigger_iteration_start(
394        &mut self,
395        iteration: usize,
396        history: &OptimizationHistory,
397    ) -> Result<(), HookError> {
398        for hook in &mut self.hooks {
399            hook.on_iteration_start(iteration, history)?;
400        }
401        Ok(())
402    }
403
404    pub fn trigger_evaluation(
405        &mut self,
406        parameters: &HashMap<String, Float>,
407        objective_value: Float,
408        iteration: usize,
409    ) -> Result<(), HookError> {
410        for hook in &mut self.hooks {
411            hook.on_evaluation(parameters, objective_value, iteration)?;
412        }
413        Ok(())
414    }
415
416    pub fn trigger_iteration_end(
417        &mut self,
418        iteration: usize,
419        history: &OptimizationHistory,
420    ) -> Result<(), HookError> {
421        for hook in &mut self.hooks {
422            hook.on_iteration_end(iteration, history)?;
423        }
424        Ok(())
425    }
426
427    pub fn trigger_optimization_end(
428        &mut self,
429        history: &OptimizationHistory,
430        reason: StopReason,
431    ) -> Result<(), HookError> {
432        for hook in &mut self.hooks {
433            hook.on_optimization_end(history, reason.clone())?;
434        }
435        Ok(())
436    }
437
438    pub fn trigger_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError> {
439        for hook in &mut self.hooks {
440            hook.on_error(error)?;
441        }
442        Ok(())
443    }
444}
445
446impl Default for HookManager {
447    fn default() -> Self {
448        Self::new()
449    }
450}
451
452// ============================================================================
453// Middleware Support
454// ============================================================================
455
456/// Middleware for optimization pipelines
457pub trait OptimizationMiddleware: Send + Sync {
458    /// Process suggestion before returning to optimizer
459    fn process_suggestion(
460        &self,
461        parameters: &mut HashMap<String, Float>,
462        history: &OptimizationHistory,
463    ) -> Result<(), MiddlewareError>;
464
465    /// Process observation before storing
466    fn process_observation(
467        &self,
468        parameters: &HashMap<String, Float>,
469        objective_value: &mut Float,
470        history: &OptimizationHistory,
471    ) -> Result<(), MiddlewareError>;
472
473    /// Middleware name
474    fn name(&self) -> &str;
475}
476
/// Errors reported by optimization middleware.
#[derive(Clone, Debug)]
pub enum MiddlewareError {
    /// Processing failed; the payload is the detail message.
    ProcessingFailed(String),
    /// Validation failed; the payload is the detail message.
    ValidationFailed(String),
}

impl std::fmt::Display for MiddlewareError {
    /// Writes `"<label>: <detail>"` for the variant at hand.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::ProcessingFailed(detail) => ("Processing failed", detail),
            Self::ValidationFailed(detail) => ("Validation failed", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for MiddlewareError {}
494
495/// Middleware pipeline
496pub struct MiddlewarePipeline {
497    middleware: Vec<Box<dyn OptimizationMiddleware>>,
498}
499
500impl MiddlewarePipeline {
501    pub fn new() -> Self {
502        Self {
503            middleware: Vec::new(),
504        }
505    }
506
507    pub fn add(&mut self, middleware: Box<dyn OptimizationMiddleware>) {
508        self.middleware.push(middleware);
509    }
510
511    pub fn process_suggestion(
512        &self,
513        parameters: &mut HashMap<String, Float>,
514        history: &OptimizationHistory,
515    ) -> Result<(), MiddlewareError> {
516        for m in &self.middleware {
517            m.process_suggestion(parameters, history)?;
518        }
519        Ok(())
520    }
521
522    pub fn process_observation(
523        &self,
524        parameters: &HashMap<String, Float>,
525        objective_value: &mut Float,
526        history: &OptimizationHistory,
527    ) -> Result<(), MiddlewareError> {
528        for m in &self.middleware {
529            m.process_observation(parameters, objective_value, history)?;
530        }
531        Ok(())
532    }
533}
534
535impl Default for MiddlewarePipeline {
536    fn default() -> Self {
537        Self::new()
538    }
539}
540
541// ============================================================================
542// Custom Metric Registration
543// ============================================================================
544
545/// Custom metric trait
546pub trait CustomMetric: Send + Sync {
547    /// Metric name
548    fn name(&self) -> &str;
549
550    /// Compute metric value
551    fn compute(
552        &self,
553        parameters: &HashMap<String, Float>,
554        objective_value: Float,
555        history: &OptimizationHistory,
556    ) -> Result<Float, MetricError>;
557
558    /// Whether higher is better
559    fn higher_is_better(&self) -> bool;
560}
561
/// Errors reported by custom metrics and the metric registry.
#[derive(Clone, Debug)]
pub enum MetricError {
    /// The metric could not be computed; the payload is the detail message.
    ComputationFailed(String),
    /// The input was rejected; the payload is the detail message.
    InvalidInput(String),
}

impl std::fmt::Display for MetricError {
    /// Writes `"<label>: <detail>"` for the variant at hand.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::ComputationFailed(detail) => ("Computation failed", detail),
            Self::InvalidInput(detail) => ("Invalid input", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for MetricError {}
579
580/// Metric registry
581pub struct MetricRegistry {
582    metrics: HashMap<String, Box<dyn CustomMetric>>,
583}
584
585impl MetricRegistry {
586    pub fn new() -> Self {
587        Self {
588            metrics: HashMap::new(),
589        }
590    }
591
592    pub fn register(&mut self, metric: Box<dyn CustomMetric>) -> Result<(), MetricError> {
593        let name = metric.name().to_string();
594        if self.metrics.contains_key(&name) {
595            return Err(MetricError::InvalidInput(format!(
596                "Metric '{}' already registered",
597                name
598            )));
599        }
600        self.metrics.insert(name, metric);
601        Ok(())
602    }
603
604    pub fn compute(
605        &self,
606        metric_name: &str,
607        parameters: &HashMap<String, Float>,
608        objective_value: Float,
609        history: &OptimizationHistory,
610    ) -> Result<Float, MetricError> {
611        let metric = self.metrics.get(metric_name).ok_or_else(|| {
612            MetricError::InvalidInput(format!("Metric '{}' not found", metric_name))
613        })?;
614
615        metric.compute(parameters, objective_value, history)
616    }
617
618    pub fn list_metrics(&self) -> Vec<String> {
619        self.metrics.keys().cloned().collect()
620    }
621}
622
623impl Default for MetricRegistry {
624    fn default() -> Self {
625        Self::new()
626    }
627}
628
629// ============================================================================
630// Example Implementations
631// ============================================================================
632
/// Hook that prints progress messages to standard output.
pub struct LoggingHook {
    /// Periodic messages are printed for iterations divisible by this value.
    /// Always at least 1.
    log_interval: usize,
}

impl LoggingHook {
    /// Creates a hook that logs every `log_interval` iterations.
    ///
    /// A `log_interval` of 0 is treated as 1 (log every iteration). The
    /// interval is later used as the right-hand side of a `%` operation,
    /// which panics on zero.
    pub fn new(log_interval: usize) -> Self {
        Self {
            log_interval: log_interval.max(1),
        }
    }
}
643
644impl OptimizationHook for LoggingHook {
645    fn on_optimization_start(
646        &mut self,
647        _config: &PluginConfig,
648        _constraints: &ParameterConstraints,
649    ) -> Result<(), HookError> {
650        println!("Optimization started");
651        Ok(())
652    }
653
654    fn on_iteration_start(
655        &mut self,
656        iteration: usize,
657        _history: &OptimizationHistory,
658    ) -> Result<(), HookError> {
659        if iteration % self.log_interval == 0 {
660            println!("Starting iteration {}", iteration);
661        }
662        Ok(())
663    }
664
665    fn on_evaluation(
666        &mut self,
667        _parameters: &HashMap<String, Float>,
668        _objective_value: Float,
669        _iteration: usize,
670    ) -> Result<(), HookError> {
671        Ok(())
672    }
673
674    fn on_iteration_end(
675        &mut self,
676        iteration: usize,
677        history: &OptimizationHistory,
678    ) -> Result<(), HookError> {
679        if iteration % self.log_interval == 0 {
680            println!(
681                "Iteration {} complete. Best so far: {}",
682                iteration, history.best_value
683            );
684        }
685        Ok(())
686    }
687
688    fn on_optimization_end(
689        &mut self,
690        history: &OptimizationHistory,
691        reason: StopReason,
692    ) -> Result<(), HookError> {
693        println!("Optimization ended. Reason: {:?}", reason);
694        println!("Best value: {}", history.best_value);
695        println!("Total evaluations: {}", history.n_evaluations);
696        Ok(())
697    }
698
699    fn on_error(&mut self, error: &dyn std::error::Error) -> Result<(), HookError> {
700        eprintln!("Error during optimization: {}", error);
701        Ok(())
702    }
703}
704
/// Middleware named `"normalization"`.
pub struct NormalizationMiddleware {
    /// Name reported by this middleware.
    name: String,
}

impl NormalizationMiddleware {
    /// Creates the middleware with its fixed name, `"normalization"`.
    pub fn new() -> Self {
        let name = String::from("normalization");
        Self { name }
    }
}

impl Default for NormalizationMiddleware {
    /// Equivalent to [`NormalizationMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
723
724impl OptimizationMiddleware for NormalizationMiddleware {
725    fn process_suggestion(
726        &self,
727        parameters: &mut HashMap<String, Float>,
728        _history: &OptimizationHistory,
729    ) -> Result<(), MiddlewareError> {
730        // Example: ensure all parameters are within valid ranges
731        for (_, value) in parameters.iter_mut() {
732            if value.is_nan() || value.is_infinite() {
733                *value = 0.0; // Default value for invalid parameters
734            }
735        }
736        Ok(())
737    }
738
739    fn process_observation(
740        &self,
741        _parameters: &HashMap<String, Float>,
742        objective_value: &mut Float,
743        _history: &OptimizationHistory,
744    ) -> Result<(), MiddlewareError> {
745        // Example: handle invalid objective values
746        if objective_value.is_nan() || objective_value.is_infinite() {
747            *objective_value = f64::NEG_INFINITY;
748        }
749        Ok(())
750    }
751
752    fn name(&self) -> &str {
753        &self.name
754    }
755}
756
757// ============================================================================
758// Tests
759// ============================================================================
760
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented default values.
    #[test]
    fn test_plugin_config() {
        let config = PluginConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.random_seed, None);
        assert!(!config.parallel);
        assert!(config.custom_params.is_empty());
    }

    /// The history tracks the best result and numbers evaluations in order.
    #[test]
    fn test_optimization_history() {
        let mut history = OptimizationHistory::new();
        assert_eq!(history.n_evaluations, 0);
        assert!(history.evaluations.is_empty());
        assert!(history.best_parameters.is_empty());

        let mut params = HashMap::new();
        params.insert("lr".to_string(), 0.01);
        history.add_evaluation(params, 0.95);

        assert_eq!(history.n_evaluations, 1);
        assert_eq!(history.best_value, 0.95);
        assert_eq!(history.best_parameters.get("lr"), Some(&0.01));

        // A worse result is recorded but must not displace the best one.
        let mut worse = HashMap::new();
        worse.insert("lr".to_string(), 0.5);
        history.add_evaluation(worse, 0.10);

        assert_eq!(history.n_evaluations, 2);
        assert_eq!(history.evaluations.len(), 2);
        assert_eq!(history.evaluations[1].iteration, 1);
        assert_eq!(history.evaluations[1].parameters.get("lr"), Some(&0.5));
        assert_eq!(history.best_value, 0.95);
        assert_eq!(history.best_parameters.get("lr"), Some(&0.01));
    }

    /// An empty registry lists nothing and reports unknown names as errors.
    #[test]
    fn test_plugin_registry() {
        let registry = PluginRegistry::new();
        assert!(registry
            .list_plugins()
            .expect("registry lock is not poisoned")
            .is_empty());

        assert!(matches!(
            registry.create_plugin("missing", &PluginConfig::default()),
            Err(PluginError::InitializationFailed(_))
        ));
        assert!(matches!(
            registry.get_plugin("missing"),
            Err(PluginError::InitializationFailed(_))
        ));
        // Unregistering an unknown name is a no-op, not an error.
        assert!(registry.unregister_plugin("missing").is_ok());
    }

    /// Every trigger reaches the logging hook without error.
    #[test]
    fn test_hook_manager() {
        let mut manager = HookManager::new();
        let hook = Box::new(LoggingHook::new(10));
        manager.add_hook(hook);

        let config = PluginConfig::default();
        let constraints = ParameterConstraints::default();
        let history = OptimizationHistory::new();

        assert!(manager
            .trigger_optimization_start(&config, &constraints)
            .is_ok());
        assert!(manager.trigger_iteration_start(0, &history).is_ok());
        assert!(manager.trigger_evaluation(&HashMap::new(), 0.5, 0).is_ok());
        assert!(manager.trigger_iteration_end(0, &history).is_ok());
        assert!(manager
            .trigger_error(&HookError::InvalidState("test".to_string()))
            .is_ok());
        assert!(manager
            .trigger_optimization_end(&history, StopReason::MaxIterationsReached)
            .is_ok());
    }

    /// Non-finite suggestions become 0.0 and non-finite objectives become
    /// negative infinity.
    #[test]
    fn test_middleware_pipeline() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add(Box::new(NormalizationMiddleware::new()));

        let mut params = HashMap::new();
        params.insert("x".to_string(), f64::NAN);
        params.insert("y".to_string(), 1.5);

        let history = OptimizationHistory::new();
        assert!(pipeline.process_suggestion(&mut params, &history).is_ok());
        assert_eq!(params.get("x"), Some(&0.0));
        assert_eq!(params.get("y"), Some(&1.5));

        let mut objective = f64::NAN;
        assert!(pipeline
            .process_observation(&params, &mut objective, &history)
            .is_ok());
        assert_eq!(objective, f64::NEG_INFINITY);
    }

    /// An empty metric registry lists nothing and rejects unknown metrics.
    #[test]
    fn test_metric_registry() {
        let registry = MetricRegistry::new();
        assert!(registry.list_metrics().is_empty());

        let outcome = registry.compute(
            "missing",
            &HashMap::new(),
            0.0,
            &OptimizationHistory::new(),
        );
        assert!(matches!(outcome, Err(MetricError::InvalidInput(_))));
    }

    /// Error types render their label followed by the detail message.
    #[test]
    fn test_error_display() {
        assert_eq!(
            PluginError::NotInitialized.to_string(),
            "Plugin not initialized"
        );
        assert_eq!(
            PluginError::InternalError("x".to_string()).to_string(),
            "Internal error: x"
        );
        assert_eq!(
            MiddlewareError::ValidationFailed("v".to_string()).to_string(),
            "Validation failed: v"
        );
    }
}