// File: sklears_kernel_approximation/plugin_architecture.rs

//! Plugin architecture for custom kernel approximations
//!
//! This module provides a flexible plugin system for registering and using
//! custom kernel approximation methods. Plugins are registered at runtime,
//! can be listed through their metadata, and are instantiated by name.
6
7use scirs2_core::ndarray::Array2;
8use serde::{Deserialize, Serialize};
9use sklears_core::error::SklearsError;
10use sklears_core::traits::{Fit, Transform};
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16/// Errors that can occur in the plugin system
17#[derive(Error, Debug)]
18/// PluginError
19pub enum PluginError {
20    #[error("Plugin not found: {name}")]
21    PluginNotFound { name: String },
22    #[error("Plugin already registered: {name}")]
23    PluginAlreadyRegistered { name: String },
24    #[error("Invalid plugin configuration: {message}")]
25    InvalidConfiguration { message: String },
26    #[error("Plugin initialization failed: {message}")]
27    InitializationFailed { message: String },
28    #[error("Type casting error for plugin: {name}")]
29    TypeCastError { name: String },
30}
31
32/// Metadata about a plugin
33#[derive(Debug, Clone, Serialize, Deserialize)]
34/// PluginMetadata
35pub struct PluginMetadata {
36    /// Plugin name
37    pub name: String,
38    /// Plugin version
39    pub version: String,
40    /// Plugin description
41    pub description: String,
42    /// Plugin author
43    pub author: String,
44    /// Supported kernel types
45    pub supported_kernels: Vec<String>,
46    /// Required parameters
47    pub required_parameters: Vec<String>,
48    /// Optional parameters
49    pub optional_parameters: Vec<String>,
50}
51
52/// Configuration for a plugin
53#[derive(Debug, Clone, Serialize, Deserialize)]
54/// PluginConfig
55pub struct PluginConfig {
56    /// Parameters for the plugin
57    pub parameters: HashMap<String, serde_json::Value>,
58    /// Random seed for reproducibility
59    pub random_state: Option<u64>,
60}
61
62impl Default for PluginConfig {
63    fn default() -> Self {
64        Self {
65            parameters: HashMap::new(),
66            random_state: None,
67        }
68    }
69}
70
71/// Trait for kernel approximation plugins
72pub trait KernelApproximationPlugin: Send + Sync {
73    /// Get plugin metadata
74    fn metadata(&self) -> PluginMetadata;
75
76    /// Create a new instance of the plugin with given configuration
77    fn create(
78        &self,
79        config: PluginConfig,
80    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError>;
81
82    /// Validate configuration
83    fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError>;
84
85    /// Get default configuration
86    fn default_config(&self) -> PluginConfig;
87}
88
89/// Instance of a kernel approximation plugin
90pub trait KernelApproximationInstance: Send + Sync {
91    /// Fit the approximation to data
92    fn fit(&mut self, x: &Array2<f64>, y: &()) -> std::result::Result<(), PluginError>;
93
94    /// Transform data using the fitted approximation
95    fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError>;
96
97    /// Check if the instance is fitted
98    fn is_fitted(&self) -> bool;
99
100    /// Get the number of output features
101    fn n_output_features(&self) -> Option<usize>;
102
103    /// Clone the instance
104    fn clone_instance(&self) -> Box<dyn KernelApproximationInstance>;
105
106    /// Get instance as Any for downcasting
107    fn as_any(&self) -> &dyn Any;
108}
109
110/// Plugin factory for creating instances
111pub struct PluginFactory {
112    plugins: Arc<RwLock<HashMap<String, Box<dyn KernelApproximationPlugin>>>>,
113}
114
115impl Default for PluginFactory {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl PluginFactory {
122    /// Create a new plugin factory
123    pub fn new() -> Self {
124        Self {
125            plugins: Arc::new(RwLock::new(HashMap::new())),
126        }
127    }
128
129    /// Register a plugin
130    pub fn register_plugin(
131        &self,
132        plugin: Box<dyn KernelApproximationPlugin>,
133    ) -> std::result::Result<(), PluginError> {
134        let metadata = plugin.metadata();
135        let mut plugins = self.plugins.write().unwrap();
136
137        if plugins.contains_key(&metadata.name) {
138            return Err(PluginError::PluginAlreadyRegistered {
139                name: metadata.name,
140            });
141        }
142
143        plugins.insert(metadata.name.clone(), plugin);
144        Ok(())
145    }
146
147    /// Unregister a plugin
148    pub fn unregister_plugin(&self, name: &str) -> std::result::Result<(), PluginError> {
149        let mut plugins = self.plugins.write().unwrap();
150        plugins
151            .remove(name)
152            .ok_or_else(|| PluginError::PluginNotFound {
153                name: name.to_string(),
154            })?;
155        Ok(())
156    }
157
158    /// Get plugin metadata
159    pub fn get_plugin_metadata(
160        &self,
161        name: &str,
162    ) -> std::result::Result<PluginMetadata, PluginError> {
163        let plugins = self.plugins.read().unwrap();
164        let plugin = plugins
165            .get(name)
166            .ok_or_else(|| PluginError::PluginNotFound {
167                name: name.to_string(),
168            })?;
169        Ok(plugin.metadata())
170    }
171
172    /// List all registered plugins
173    pub fn list_plugins(&self) -> Vec<PluginMetadata> {
174        let plugins = self.plugins.read().unwrap();
175        plugins.values().map(|p| p.metadata()).collect()
176    }
177
178    /// Create an instance of a plugin
179    pub fn create_instance(
180        &self,
181        name: &str,
182        config: PluginConfig,
183    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
184        let plugins = self.plugins.read().unwrap();
185        let plugin = plugins
186            .get(name)
187            .ok_or_else(|| PluginError::PluginNotFound {
188                name: name.to_string(),
189            })?;
190
191        plugin.validate_config(&config)?;
192        plugin.create(config)
193    }
194
195    /// Get default configuration for a plugin
196    pub fn get_default_config(&self, name: &str) -> std::result::Result<PluginConfig, PluginError> {
197        let plugins = self.plugins.read().unwrap();
198        let plugin = plugins
199            .get(name)
200            .ok_or_else(|| PluginError::PluginNotFound {
201                name: name.to_string(),
202            })?;
203        Ok(plugin.default_config())
204    }
205}
206
207/// Wrapper to make plugin instances compatible with sklears traits
208pub struct PluginWrapper {
209    instance: Box<dyn KernelApproximationInstance>,
210    metadata: PluginMetadata,
211}
212
213impl PluginWrapper {
214    /// Create a new plugin wrapper
215    pub fn new(instance: Box<dyn KernelApproximationInstance>, metadata: PluginMetadata) -> Self {
216        Self { instance, metadata }
217    }
218
219    /// Get plugin metadata
220    pub fn metadata(&self) -> &PluginMetadata {
221        &self.metadata
222    }
223
224    /// Get the underlying instance
225    pub fn instance(&self) -> &dyn KernelApproximationInstance {
226        self.instance.as_ref()
227    }
228
229    /// Get mutable access to the underlying instance
230    pub fn instance_mut(&mut self) -> &mut dyn KernelApproximationInstance {
231        self.instance.as_mut()
232    }
233}
234
235impl Clone for PluginWrapper {
236    fn clone(&self) -> Self {
237        Self {
238            instance: self.instance.clone_instance(),
239            metadata: self.metadata.clone(),
240        }
241    }
242}
243
244impl Fit<Array2<f64>, ()> for PluginWrapper {
245    type Fitted = FittedPluginWrapper;
246
247    fn fit(mut self, x: &Array2<f64>, y: &()) -> Result<Self::Fitted, SklearsError> {
248        self.instance
249            .fit(x, y)
250            .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))?;
251        Ok(FittedPluginWrapper {
252            instance: self.instance,
253            metadata: self.metadata,
254        })
255    }
256}
257
258/// Fitted plugin wrapper
259pub struct FittedPluginWrapper {
260    instance: Box<dyn KernelApproximationInstance>,
261    metadata: PluginMetadata,
262}
263
264impl FittedPluginWrapper {
265    /// Get plugin metadata
266    pub fn metadata(&self) -> &PluginMetadata {
267        &self.metadata
268    }
269
270    /// Get the underlying instance
271    pub fn instance(&self) -> &dyn KernelApproximationInstance {
272        self.instance.as_ref()
273    }
274
275    /// Get the number of output features
276    pub fn n_output_features(&self) -> Option<usize> {
277        self.instance.n_output_features()
278    }
279}
280
281impl Clone for FittedPluginWrapper {
282    fn clone(&self) -> Self {
283        Self {
284            instance: self.instance.clone_instance(),
285            metadata: self.metadata.clone(),
286        }
287    }
288}
289
290impl Transform<Array2<f64>, Array2<f64>> for FittedPluginWrapper {
291    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
292        self.instance
293            .transform(x)
294            .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))
295    }
296}
297
298/// Global plugin registry
299static GLOBAL_FACTORY: std::sync::LazyLock<PluginFactory> =
300    std::sync::LazyLock::new(PluginFactory::new);
301
302/// Register a plugin globally
303pub fn register_global_plugin(
304    plugin: Box<dyn KernelApproximationPlugin>,
305) -> std::result::Result<(), PluginError> {
306    GLOBAL_FACTORY.register_plugin(plugin)
307}
308
309/// Create an instance from the global registry
310pub fn create_global_plugin_instance(
311    name: &str,
312    config: PluginConfig,
313) -> std::result::Result<PluginWrapper, PluginError> {
314    let instance = GLOBAL_FACTORY.create_instance(name, config)?;
315    let metadata = GLOBAL_FACTORY.get_plugin_metadata(name)?;
316    Ok(PluginWrapper::new(instance, metadata))
317}
318
319/// List all globally registered plugins
320pub fn list_global_plugins() -> Vec<PluginMetadata> {
321    GLOBAL_FACTORY.list_plugins()
322}
323
324/// Example plugin implementing a simple linear kernel approximation
325pub struct LinearKernelPlugin;
326
327impl KernelApproximationPlugin for LinearKernelPlugin {
328    fn metadata(&self) -> PluginMetadata {
329        /// PluginMetadata
330        PluginMetadata {
331            name: "linear_kernel".to_string(),
332            version: "1.0.0".to_string(),
333            description: "Simple linear kernel approximation plugin".to_string(),
334            author: "sklears".to_string(),
335            supported_kernels: vec!["linear".to_string()],
336            required_parameters: vec!["n_components".to_string()],
337            optional_parameters: vec!["normalize".to_string()],
338        }
339    }
340
341    fn create(
342        &self,
343        config: PluginConfig,
344    ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
345        let n_components = config
346            .parameters
347            .get("n_components")
348            .and_then(|v| v.as_u64())
349            .ok_or_else(|| PluginError::InvalidConfiguration {
350                message: "n_components parameter required".to_string(),
351            })? as usize;
352
353        let normalize = config
354            .parameters
355            .get("normalize")
356            .and_then(|v| v.as_bool())
357            .unwrap_or(false);
358
359        Ok(Box::new(LinearKernelInstance {
360            n_components,
361            normalize,
362            projection_matrix: None,
363        }))
364    }
365
366    fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError> {
367        if !config.parameters.contains_key("n_components") {
368            return Err(PluginError::InvalidConfiguration {
369                message: "n_components parameter is required".to_string(),
370            });
371        }
372
373        if let Some(n_comp) = config
374            .parameters
375            .get("n_components")
376            .and_then(|v| v.as_u64())
377        {
378            if n_comp == 0 {
379                return Err(PluginError::InvalidConfiguration {
380                    message: "n_components must be greater than 0".to_string(),
381                });
382            }
383        } else {
384            return Err(PluginError::InvalidConfiguration {
385                message: "n_components must be a positive integer".to_string(),
386            });
387        }
388
389        Ok(())
390    }
391
392    fn default_config(&self) -> PluginConfig {
393        let mut config = PluginConfig::default();
394        config.parameters.insert(
395            "n_components".to_string(),
396            serde_json::Value::Number(100.into()),
397        );
398        config
399            .parameters
400            .insert("normalize".to_string(), serde_json::Value::Bool(false));
401        config
402    }
403}
404
405/// Linear kernel instance
406pub struct LinearKernelInstance {
407    n_components: usize,
408    normalize: bool,
409    projection_matrix: Option<Array2<f64>>,
410}
411
412impl KernelApproximationInstance for LinearKernelInstance {
413    fn fit(&mut self, x: &Array2<f64>, _y: &()) -> std::result::Result<(), PluginError> {
414        use scirs2_core::random::thread_rng;
415        use scirs2_core::random::{Distribution, StandardNormal};
416
417        let (_, n_features) = x.dim();
418        let mut rng = thread_rng();
419
420        // Create random projection matrix
421        let mut proj_matrix = Array2::zeros((n_features, self.n_components));
422        for elem in proj_matrix.iter_mut() {
423            *elem = rng.sample(StandardNormal);
424        }
425
426        if self.normalize {
427            // Normalize columns to unit length
428            for j in 0..self.n_components {
429                let mut col = proj_matrix.column_mut(j);
430                let norm = col.mapv(|x: f64| x * x).sum().sqrt();
431                if norm > 1e-8 {
432                    col /= norm;
433                }
434            }
435        }
436
437        self.projection_matrix = Some(proj_matrix);
438        Ok(())
439    }
440
441    fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError> {
442        let proj_matrix =
443            self.projection_matrix
444                .as_ref()
445                .ok_or_else(|| PluginError::InitializationFailed {
446                    message: "Plugin not fitted".to_string(),
447                })?;
448
449        Ok(x.dot(proj_matrix))
450    }
451
452    fn is_fitted(&self) -> bool {
453        self.projection_matrix.is_some()
454    }
455
456    fn n_output_features(&self) -> Option<usize> {
457        if self.is_fitted() {
458            Some(self.n_components)
459        } else {
460            None
461        }
462    }
463
464    fn clone_instance(&self) -> Box<dyn KernelApproximationInstance> {
465        Box::new(LinearKernelInstance {
466            n_components: self.n_components,
467            normalize: self.normalize,
468            projection_matrix: self.projection_matrix.clone(),
469        })
470    }
471
472    fn as_any(&self) -> &dyn Any {
473        self
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds a fresh factory with [`LinearKernelPlugin`] registered.
    fn factory_with_linear_plugin() -> PluginFactory {
        let factory = PluginFactory::new();
        factory.register_plugin(Box::new(LinearKernelPlugin)).unwrap();
        factory
    }

    /// Builds a config containing only `n_components`.
    fn config_with_components(n_components: u64) -> PluginConfig {
        let mut config = PluginConfig::default();
        config.parameters.insert(
            "n_components".to_string(),
            serde_json::Value::Number(n_components.into()),
        );
        config
    }

    #[test]
    fn test_plugin_registration() {
        let factory = PluginFactory::new();
        let plugin = Box::new(LinearKernelPlugin);

        assert!(factory.register_plugin(plugin).is_ok());

        let plugins = factory.list_plugins();
        assert_eq!(plugins.len(), 1);
        assert_eq!(plugins[0].name, "linear_kernel");
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let factory = factory_with_linear_plugin();
        let err = factory
            .register_plugin(Box::new(LinearKernelPlugin))
            .unwrap_err();
        assert!(matches!(err, PluginError::PluginAlreadyRegistered { .. }));
        assert_eq!(factory.list_plugins().len(), 1);
    }

    #[test]
    fn test_unregister_plugin() {
        let factory = factory_with_linear_plugin();
        assert!(factory.unregister_plugin("linear_kernel").is_ok());
        assert!(factory.list_plugins().is_empty());
        assert!(matches!(
            factory.unregister_plugin("linear_kernel"),
            Err(PluginError::PluginNotFound { .. })
        ));
    }

    #[test]
    fn test_unknown_plugin_is_reported() {
        let factory = PluginFactory::new();
        let result = factory.create_instance("missing", config_with_components(10));
        assert!(matches!(result, Err(PluginError::PluginNotFound { .. })));
        assert!(matches!(
            factory.get_plugin_metadata("missing"),
            Err(PluginError::PluginNotFound { .. })
        ));
    }

    #[test]
    fn test_plugin_instance_creation() {
        let factory = factory_with_linear_plugin();
        let instance = factory.create_instance("linear_kernel", config_with_components(50));
        assert!(instance.is_ok());
    }

    #[test]
    fn test_default_config_is_accepted() {
        let factory = factory_with_linear_plugin();
        let config = factory.get_default_config("linear_kernel").unwrap();
        assert!(factory.create_instance("linear_kernel", config).is_ok());
    }

    #[test]
    fn test_plugin_wrapper_fit_transform() {
        let factory = factory_with_linear_plugin();

        let instance = factory
            .create_instance("linear_kernel", config_with_components(30))
            .unwrap();
        let metadata = factory.get_plugin_metadata("linear_kernel").unwrap();
        let wrapper = PluginWrapper::new(instance, metadata);
        assert_eq!(wrapper.instance().n_output_features(), None);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = wrapper.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 30]);
        assert_eq!(fitted.n_output_features(), Some(30));
    }

    #[test]
    fn test_transform_before_fit_fails() {
        let factory = factory_with_linear_plugin();
        let instance = factory
            .create_instance("linear_kernel", config_with_components(5))
            .unwrap();
        assert!(!instance.is_fitted());

        let x = array![[1.0, 2.0]];
        assert!(instance.transform(&x).is_err());
    }

    #[test]
    fn test_zero_components_rejected() {
        let factory = factory_with_linear_plugin();
        let result = factory.create_instance("linear_kernel", config_with_components(0));
        assert!(matches!(
            result,
            Err(PluginError::InvalidConfiguration { .. })
        ));
    }

    #[test]
    fn test_normalize_produces_unit_columns() {
        let factory = factory_with_linear_plugin();
        let mut config = config_with_components(4);
        config
            .parameters
            .insert("normalize".to_string(), serde_json::Value::Bool(true));
        let mut instance = factory.create_instance("linear_kernel", config).unwrap();

        let identity = Array2::<f64>::eye(3);
        instance.fit(&identity, &()).unwrap();
        // Transforming the identity returns the projection matrix itself.
        let projection = instance.transform(&identity).unwrap();
        for j in 0..projection.ncols() {
            let norm = projection
                .column(j)
                .iter()
                .map(|v| v * v)
                .sum::<f64>()
                .sqrt();
            assert!((norm - 1.0).abs() < 1e-12, "column {j} has norm {norm}");
        }
    }

    #[test]
    fn test_global_plugin_registry() {
        // The registry is process-wide, so tolerate a registration made earlier
        // in the same test binary.
        match register_global_plugin(Box::new(LinearKernelPlugin)) {
            Ok(()) | Err(PluginError::PluginAlreadyRegistered { .. }) => {}
            Err(other) => panic!("unexpected registration error: {other}"),
        }

        let plugins = list_global_plugins();
        assert!(plugins.iter().any(|meta| meta.name == "linear_kernel"));

        let wrapper =
            create_global_plugin_instance("linear_kernel", config_with_components(8)).unwrap();
        assert_eq!(wrapper.metadata().name, "linear_kernel");
    }

    #[test]
    fn test_invalid_configuration() {
        let factory = factory_with_linear_plugin();

        let config = PluginConfig::default(); // Missing n_components
        let result = factory.create_instance("linear_kernel", config);
        assert!(result.is_err());
    }
}