sklears_kernel_approximation/
plugin_architecture.rs

//! Plugin architecture for custom kernel approximations
//!
//! This module provides a flexible plugin system for registering and using
//! custom kernel approximation methods. It allows runtime discovery and
//! instantiation of kernel approximation plugins.

7use scirs2_core::ndarray::Array2;
8use serde::{Deserialize, Serialize};
9use sklears_core::error::SklearsError;
10use sklears_core::traits::{Fit, Transform};
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16/// Errors that can occur in the plugin system
17#[derive(Error, Debug)]
18/// PluginError
19pub enum PluginError {
20    #[error("Plugin not found: {name}")]
21    PluginNotFound { name: String },
22    #[error("Plugin already registered: {name}")]
23    PluginAlreadyRegistered { name: String },
24    #[error("Invalid plugin configuration: {message}")]
25    InvalidConfiguration { message: String },
26    #[error("Plugin initialization failed: {message}")]
27    InitializationFailed { message: String },
28    #[error("Type casting error for plugin: {name}")]
29    TypeCastError { name: String },
30}
31
32/// Metadata about a plugin
33#[derive(Debug, Clone, Serialize, Deserialize)]
34/// PluginMetadata
35pub struct PluginMetadata {
36    /// Plugin name
37    pub name: String,
38    /// Plugin version
39    pub version: String,
40    /// Plugin description
41    pub description: String,
42    /// Plugin author
43    pub author: String,
44    /// Supported kernel types
45    pub supported_kernels: Vec<String>,
46    /// Required parameters
47    pub required_parameters: Vec<String>,
48    /// Optional parameters
49    pub optional_parameters: Vec<String>,
50}
51
52/// Configuration for a plugin
53#[derive(Debug, Clone, Serialize, Deserialize)]
54/// PluginConfig
55#[derive(Default)]
56pub struct PluginConfig {
57    /// Parameters for the plugin
58    pub parameters: HashMap<String, serde_json::Value>,
59    /// Random seed for reproducibility
60    pub random_state: Option<u64>,
61}
62
63/// Trait for kernel approximation plugins
64pub trait KernelApproximationPlugin: Send + Sync {
65    /// Get plugin metadata
66    fn metadata(&self) -> PluginMetadata;
67
68    /// Create a new instance of the plugin with given configuration
69    fn create(
70        &self,
71        config: PluginConfig,
72    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError>;
73
74    /// Validate configuration
75    fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError>;
76
77    /// Get default configuration
78    fn default_config(&self) -> PluginConfig;
79}
80
81/// Instance of a kernel approximation plugin
82pub trait KernelApproximationInstance: Send + Sync {
83    /// Fit the approximation to data
84    fn fit(&mut self, x: &Array2<f64>, y: &()) -> std::result::Result<(), PluginError>;
85
86    /// Transform data using the fitted approximation
87    fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError>;
88
89    /// Check if the instance is fitted
90    fn is_fitted(&self) -> bool;
91
92    /// Get the number of output features
93    fn n_output_features(&self) -> Option<usize>;
94
95    /// Clone the instance
96    fn clone_instance(&self) -> Box<dyn KernelApproximationInstance>;
97
98    /// Get instance as Any for downcasting
99    fn as_any(&self) -> &dyn Any;
100}
101
102/// Plugin factory for creating instances
103pub struct PluginFactory {
104    plugins: Arc<RwLock<HashMap<String, Box<dyn KernelApproximationPlugin>>>>,
105}
106
107impl Default for PluginFactory {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl PluginFactory {
114    /// Create a new plugin factory
115    pub fn new() -> Self {
116        Self {
117            plugins: Arc::new(RwLock::new(HashMap::new())),
118        }
119    }
120
121    /// Register a plugin
122    pub fn register_plugin(
123        &self,
124        plugin: Box<dyn KernelApproximationPlugin>,
125    ) -> std::result::Result<(), PluginError> {
126        let metadata = plugin.metadata();
127        let mut plugins = self.plugins.write().unwrap();
128
129        if plugins.contains_key(&metadata.name) {
130            return Err(PluginError::PluginAlreadyRegistered {
131                name: metadata.name,
132            });
133        }
134
135        plugins.insert(metadata.name.clone(), plugin);
136        Ok(())
137    }
138
139    /// Unregister a plugin
140    pub fn unregister_plugin(&self, name: &str) -> std::result::Result<(), PluginError> {
141        let mut plugins = self.plugins.write().unwrap();
142        plugins
143            .remove(name)
144            .ok_or_else(|| PluginError::PluginNotFound {
145                name: name.to_string(),
146            })?;
147        Ok(())
148    }
149
150    /// Get plugin metadata
151    pub fn get_plugin_metadata(
152        &self,
153        name: &str,
154    ) -> std::result::Result<PluginMetadata, PluginError> {
155        let plugins = self.plugins.read().unwrap();
156        let plugin = plugins
157            .get(name)
158            .ok_or_else(|| PluginError::PluginNotFound {
159                name: name.to_string(),
160            })?;
161        Ok(plugin.metadata())
162    }
163
164    /// List all registered plugins
165    pub fn list_plugins(&self) -> Vec<PluginMetadata> {
166        let plugins = self.plugins.read().unwrap();
167        plugins.values().map(|p| p.metadata()).collect()
168    }
169
170    /// Create an instance of a plugin
171    pub fn create_instance(
172        &self,
173        name: &str,
174        config: PluginConfig,
175    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
176        let plugins = self.plugins.read().unwrap();
177        let plugin = plugins
178            .get(name)
179            .ok_or_else(|| PluginError::PluginNotFound {
180                name: name.to_string(),
181            })?;
182
183        plugin.validate_config(&config)?;
184        plugin.create(config)
185    }
186
187    /// Get default configuration for a plugin
188    pub fn get_default_config(&self, name: &str) -> std::result::Result<PluginConfig, PluginError> {
189        let plugins = self.plugins.read().unwrap();
190        let plugin = plugins
191            .get(name)
192            .ok_or_else(|| PluginError::PluginNotFound {
193                name: name.to_string(),
194            })?;
195        Ok(plugin.default_config())
196    }
197}
198
199/// Wrapper to make plugin instances compatible with sklears traits
200pub struct PluginWrapper {
201    instance: Box<dyn KernelApproximationInstance>,
202    metadata: PluginMetadata,
203}
204
205impl PluginWrapper {
206    /// Create a new plugin wrapper
207    pub fn new(instance: Box<dyn KernelApproximationInstance>, metadata: PluginMetadata) -> Self {
208        Self { instance, metadata }
209    }
210
211    /// Get plugin metadata
212    pub fn metadata(&self) -> &PluginMetadata {
213        &self.metadata
214    }
215
216    /// Get the underlying instance
217    pub fn instance(&self) -> &dyn KernelApproximationInstance {
218        self.instance.as_ref()
219    }
220
221    /// Get mutable access to the underlying instance
222    pub fn instance_mut(&mut self) -> &mut dyn KernelApproximationInstance {
223        self.instance.as_mut()
224    }
225}
226
227impl Clone for PluginWrapper {
228    fn clone(&self) -> Self {
229        Self {
230            instance: self.instance.clone_instance(),
231            metadata: self.metadata.clone(),
232        }
233    }
234}
235
236impl Fit<Array2<f64>, ()> for PluginWrapper {
237    type Fitted = FittedPluginWrapper;
238
239    fn fit(mut self, x: &Array2<f64>, y: &()) -> Result<Self::Fitted, SklearsError> {
240        self.instance
241            .fit(x, y)
242            .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))?;
243        Ok(FittedPluginWrapper {
244            instance: self.instance,
245            metadata: self.metadata,
246        })
247    }
248}
249
250/// Fitted plugin wrapper
251pub struct FittedPluginWrapper {
252    instance: Box<dyn KernelApproximationInstance>,
253    metadata: PluginMetadata,
254}
255
256impl FittedPluginWrapper {
257    /// Get plugin metadata
258    pub fn metadata(&self) -> &PluginMetadata {
259        &self.metadata
260    }
261
262    /// Get the underlying instance
263    pub fn instance(&self) -> &dyn KernelApproximationInstance {
264        self.instance.as_ref()
265    }
266
267    /// Get the number of output features
268    pub fn n_output_features(&self) -> Option<usize> {
269        self.instance.n_output_features()
270    }
271}
272
273impl Clone for FittedPluginWrapper {
274    fn clone(&self) -> Self {
275        Self {
276            instance: self.instance.clone_instance(),
277            metadata: self.metadata.clone(),
278        }
279    }
280}
281
282impl Transform<Array2<f64>, Array2<f64>> for FittedPluginWrapper {
283    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
284        self.instance
285            .transform(x)
286            .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))
287    }
288}
289
290/// Global plugin registry
291static GLOBAL_FACTORY: std::sync::LazyLock<PluginFactory> =
292    std::sync::LazyLock::new(PluginFactory::new);
293
294/// Register a plugin globally
295pub fn register_global_plugin(
296    plugin: Box<dyn KernelApproximationPlugin>,
297) -> std::result::Result<(), PluginError> {
298    GLOBAL_FACTORY.register_plugin(plugin)
299}
300
301/// Create an instance from the global registry
302pub fn create_global_plugin_instance(
303    name: &str,
304    config: PluginConfig,
305) -> std::result::Result<PluginWrapper, PluginError> {
306    let instance = GLOBAL_FACTORY.create_instance(name, config)?;
307    let metadata = GLOBAL_FACTORY.get_plugin_metadata(name)?;
308    Ok(PluginWrapper::new(instance, metadata))
309}
310
311/// List all globally registered plugins
312pub fn list_global_plugins() -> Vec<PluginMetadata> {
313    GLOBAL_FACTORY.list_plugins()
314}
315
/// Example plugin implementing a simple linear kernel approximation.
///
/// Registered under the name `"linear_kernel"`. Its instances apply a random
/// Gaussian projection to the input features. The type is a stateless unit
/// struct, so it derives the usual value traits (`Debug`, `Clone`, `Copy`,
/// `Default`, `PartialEq`, `Eq`) expected of a public type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LinearKernelPlugin;
318
319impl KernelApproximationPlugin for LinearKernelPlugin {
320    fn metadata(&self) -> PluginMetadata {
321        PluginMetadata {
322            name: "linear_kernel".to_string(),
323            version: "1.0.0".to_string(),
324            description: "Simple linear kernel approximation plugin".to_string(),
325            author: "sklears".to_string(),
326            supported_kernels: vec!["linear".to_string()],
327            required_parameters: vec!["n_components".to_string()],
328            optional_parameters: vec!["normalize".to_string()],
329        }
330    }
331
332    fn create(
333        &self,
334        config: PluginConfig,
335    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
336        let n_components = config
337            .parameters
338            .get("n_components")
339            .and_then(|v| v.as_u64())
340            .ok_or_else(|| PluginError::InvalidConfiguration {
341                message: "n_components parameter required".to_string(),
342            })? as usize;
343
344        let normalize = config
345            .parameters
346            .get("normalize")
347            .and_then(|v| v.as_bool())
348            .unwrap_or(false);
349
350        Ok(Box::new(LinearKernelInstance {
351            n_components,
352            normalize,
353            projection_matrix: None,
354        }))
355    }
356
357    fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError> {
358        if !config.parameters.contains_key("n_components") {
359            return Err(PluginError::InvalidConfiguration {
360                message: "n_components parameter is required".to_string(),
361            });
362        }
363
364        if let Some(n_comp) = config
365            .parameters
366            .get("n_components")
367            .and_then(|v| v.as_u64())
368        {
369            if n_comp == 0 {
370                return Err(PluginError::InvalidConfiguration {
371                    message: "n_components must be greater than 0".to_string(),
372                });
373            }
374        } else {
375            return Err(PluginError::InvalidConfiguration {
376                message: "n_components must be a positive integer".to_string(),
377            });
378        }
379
380        Ok(())
381    }
382
383    fn default_config(&self) -> PluginConfig {
384        let mut config = PluginConfig::default();
385        config.parameters.insert(
386            "n_components".to_string(),
387            serde_json::Value::Number(100.into()),
388        );
389        config
390            .parameters
391            .insert("normalize".to_string(), serde_json::Value::Bool(false));
392        config
393    }
394}
395
396/// Linear kernel instance
397pub struct LinearKernelInstance {
398    n_components: usize,
399    normalize: bool,
400    projection_matrix: Option<Array2<f64>>,
401}
402
403impl KernelApproximationInstance for LinearKernelInstance {
404    fn fit(&mut self, x: &Array2<f64>, _y: &()) -> std::result::Result<(), PluginError> {
405        use scirs2_core::random::thread_rng;
406        use scirs2_core::random::StandardNormal;
407
408        let (_, n_features) = x.dim();
409        let mut rng = thread_rng();
410
411        // Create random projection matrix
412        let mut proj_matrix = Array2::zeros((n_features, self.n_components));
413        for elem in proj_matrix.iter_mut() {
414            *elem = rng.sample(StandardNormal);
415        }
416
417        if self.normalize {
418            // Normalize columns to unit length
419            for j in 0..self.n_components {
420                let mut col = proj_matrix.column_mut(j);
421                let norm = col.mapv(|x: f64| x * x).sum().sqrt();
422                if norm > 1e-8 {
423                    col /= norm;
424                }
425            }
426        }
427
428        self.projection_matrix = Some(proj_matrix);
429        Ok(())
430    }
431
432    fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError> {
433        let proj_matrix =
434            self.projection_matrix
435                .as_ref()
436                .ok_or_else(|| PluginError::InitializationFailed {
437                    message: "Plugin not fitted".to_string(),
438                })?;
439
440        Ok(x.dot(proj_matrix))
441    }
442
443    fn is_fitted(&self) -> bool {
444        self.projection_matrix.is_some()
445    }
446
447    fn n_output_features(&self) -> Option<usize> {
448        if self.is_fitted() {
449            Some(self.n_components)
450        } else {
451            None
452        }
453    }
454
455    fn clone_instance(&self) -> Box<dyn KernelApproximationInstance> {
456        Box::new(LinearKernelInstance {
457            n_components: self.n_components,
458            normalize: self.normalize,
459            projection_matrix: self.projection_matrix.clone(),
460        })
461    }
462
463    fn as_any(&self) -> &dyn Any {
464        self
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fresh factory with the bundled linear plugin already registered.
    fn factory_with_linear_plugin() -> PluginFactory {
        let factory = PluginFactory::new();
        factory
            .register_plugin(Box::new(LinearKernelPlugin))
            .unwrap();
        factory
    }

    /// Configuration that sets only `n_components`.
    fn config_with_components(n_components: u64) -> PluginConfig {
        let mut config = PluginConfig::default();
        config.parameters.insert(
            "n_components".to_string(),
            serde_json::Value::from(n_components),
        );
        config
    }

    #[test]
    fn test_plugin_registration() {
        let factory = PluginFactory::new();
        assert!(factory.register_plugin(Box::new(LinearKernelPlugin)).is_ok());

        let listed = factory.list_plugins();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "linear_kernel");
    }

    #[test]
    fn test_plugin_instance_creation() {
        let factory = factory_with_linear_plugin();
        let created = factory.create_instance("linear_kernel", config_with_components(50));
        assert!(created.is_ok());
    }

    #[test]
    fn test_plugin_wrapper_fit_transform() {
        let factory = factory_with_linear_plugin();
        let instance = factory
            .create_instance("linear_kernel", config_with_components(30))
            .unwrap();
        let metadata = factory.get_plugin_metadata("linear_kernel").unwrap();

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = PluginWrapper::new(instance, metadata).fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 30]);
    }

    #[test]
    fn test_global_plugin_registry() {
        assert!(register_global_plugin(Box::new(LinearKernelPlugin)).is_ok());
        assert!(!list_global_plugins().is_empty());
    }

    #[test]
    fn test_invalid_configuration() {
        let factory = factory_with_linear_plugin();
        // `n_components` is deliberately missing.
        let outcome = factory.create_instance("linear_kernel", PluginConfig::default());
        assert!(outcome.is_err());
    }
}
544}