sklears_kernel_approximation/
plugin_architecture.rs1use scirs2_core::ndarray::Array2;
8use serde::{Deserialize, Serialize};
9use sklears_core::error::SklearsError;
10use sklears_core::traits::{Fit, Transform};
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16#[derive(Error, Debug)]
18pub enum PluginError {
20 #[error("Plugin not found: {name}")]
21 PluginNotFound { name: String },
22 #[error("Plugin already registered: {name}")]
23 PluginAlreadyRegistered { name: String },
24 #[error("Invalid plugin configuration: {message}")]
25 InvalidConfiguration { message: String },
26 #[error("Plugin initialization failed: {message}")]
27 InitializationFailed { message: String },
28 #[error("Type casting error for plugin: {name}")]
29 TypeCastError { name: String },
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PluginMetadata {
36 pub name: String,
38 pub version: String,
40 pub description: String,
42 pub author: String,
44 pub supported_kernels: Vec<String>,
46 pub required_parameters: Vec<String>,
48 pub optional_parameters: Vec<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct PluginConfig {
56 pub parameters: HashMap<String, serde_json::Value>,
58 pub random_state: Option<u64>,
60}
61
62impl Default for PluginConfig {
63 fn default() -> Self {
64 Self {
65 parameters: HashMap::new(),
66 random_state: None,
67 }
68 }
69}
70
71pub trait KernelApproximationPlugin: Send + Sync {
73 fn metadata(&self) -> PluginMetadata;
75
76 fn create(
78 &self,
79 config: PluginConfig,
80 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError>;
81
82 fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError>;
84
85 fn default_config(&self) -> PluginConfig;
87}
88
89pub trait KernelApproximationInstance: Send + Sync {
91 fn fit(&mut self, x: &Array2<f64>, y: &()) -> std::result::Result<(), PluginError>;
93
94 fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError>;
96
97 fn is_fitted(&self) -> bool;
99
100 fn n_output_features(&self) -> Option<usize>;
102
103 fn clone_instance(&self) -> Box<dyn KernelApproximationInstance>;
105
106 fn as_any(&self) -> &dyn Any;
108}
109
110pub struct PluginFactory {
112 plugins: Arc<RwLock<HashMap<String, Box<dyn KernelApproximationPlugin>>>>,
113}
114
115impl Default for PluginFactory {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl PluginFactory {
122 pub fn new() -> Self {
124 Self {
125 plugins: Arc::new(RwLock::new(HashMap::new())),
126 }
127 }
128
129 pub fn register_plugin(
131 &self,
132 plugin: Box<dyn KernelApproximationPlugin>,
133 ) -> std::result::Result<(), PluginError> {
134 let metadata = plugin.metadata();
135 let mut plugins = self.plugins.write().unwrap();
136
137 if plugins.contains_key(&metadata.name) {
138 return Err(PluginError::PluginAlreadyRegistered {
139 name: metadata.name,
140 });
141 }
142
143 plugins.insert(metadata.name.clone(), plugin);
144 Ok(())
145 }
146
147 pub fn unregister_plugin(&self, name: &str) -> std::result::Result<(), PluginError> {
149 let mut plugins = self.plugins.write().unwrap();
150 plugins
151 .remove(name)
152 .ok_or_else(|| PluginError::PluginNotFound {
153 name: name.to_string(),
154 })?;
155 Ok(())
156 }
157
158 pub fn get_plugin_metadata(
160 &self,
161 name: &str,
162 ) -> std::result::Result<PluginMetadata, PluginError> {
163 let plugins = self.plugins.read().unwrap();
164 let plugin = plugins
165 .get(name)
166 .ok_or_else(|| PluginError::PluginNotFound {
167 name: name.to_string(),
168 })?;
169 Ok(plugin.metadata())
170 }
171
172 pub fn list_plugins(&self) -> Vec<PluginMetadata> {
174 let plugins = self.plugins.read().unwrap();
175 plugins.values().map(|p| p.metadata()).collect()
176 }
177
178 pub fn create_instance(
180 &self,
181 name: &str,
182 config: PluginConfig,
183 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
184 let plugins = self.plugins.read().unwrap();
185 let plugin = plugins
186 .get(name)
187 .ok_or_else(|| PluginError::PluginNotFound {
188 name: name.to_string(),
189 })?;
190
191 plugin.validate_config(&config)?;
192 plugin.create(config)
193 }
194
195 pub fn get_default_config(&self, name: &str) -> std::result::Result<PluginConfig, PluginError> {
197 let plugins = self.plugins.read().unwrap();
198 let plugin = plugins
199 .get(name)
200 .ok_or_else(|| PluginError::PluginNotFound {
201 name: name.to_string(),
202 })?;
203 Ok(plugin.default_config())
204 }
205}
206
207pub struct PluginWrapper {
209 instance: Box<dyn KernelApproximationInstance>,
210 metadata: PluginMetadata,
211}
212
213impl PluginWrapper {
214 pub fn new(instance: Box<dyn KernelApproximationInstance>, metadata: PluginMetadata) -> Self {
216 Self { instance, metadata }
217 }
218
219 pub fn metadata(&self) -> &PluginMetadata {
221 &self.metadata
222 }
223
224 pub fn instance(&self) -> &dyn KernelApproximationInstance {
226 self.instance.as_ref()
227 }
228
229 pub fn instance_mut(&mut self) -> &mut dyn KernelApproximationInstance {
231 self.instance.as_mut()
232 }
233}
234
235impl Clone for PluginWrapper {
236 fn clone(&self) -> Self {
237 Self {
238 instance: self.instance.clone_instance(),
239 metadata: self.metadata.clone(),
240 }
241 }
242}
243
244impl Fit<Array2<f64>, ()> for PluginWrapper {
245 type Fitted = FittedPluginWrapper;
246
247 fn fit(mut self, x: &Array2<f64>, y: &()) -> Result<Self::Fitted, SklearsError> {
248 self.instance
249 .fit(x, y)
250 .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))?;
251 Ok(FittedPluginWrapper {
252 instance: self.instance,
253 metadata: self.metadata,
254 })
255 }
256}
257
258pub struct FittedPluginWrapper {
260 instance: Box<dyn KernelApproximationInstance>,
261 metadata: PluginMetadata,
262}
263
264impl FittedPluginWrapper {
265 pub fn metadata(&self) -> &PluginMetadata {
267 &self.metadata
268 }
269
270 pub fn instance(&self) -> &dyn KernelApproximationInstance {
272 self.instance.as_ref()
273 }
274
275 pub fn n_output_features(&self) -> Option<usize> {
277 self.instance.n_output_features()
278 }
279}
280
281impl Clone for FittedPluginWrapper {
282 fn clone(&self) -> Self {
283 Self {
284 instance: self.instance.clone_instance(),
285 metadata: self.metadata.clone(),
286 }
287 }
288}
289
290impl Transform<Array2<f64>, Array2<f64>> for FittedPluginWrapper {
291 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
292 self.instance
293 .transform(x)
294 .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))
295 }
296}
297
298static GLOBAL_FACTORY: std::sync::LazyLock<PluginFactory> =
300 std::sync::LazyLock::new(PluginFactory::new);
301
302pub fn register_global_plugin(
304 plugin: Box<dyn KernelApproximationPlugin>,
305) -> std::result::Result<(), PluginError> {
306 GLOBAL_FACTORY.register_plugin(plugin)
307}
308
309pub fn create_global_plugin_instance(
311 name: &str,
312 config: PluginConfig,
313) -> std::result::Result<PluginWrapper, PluginError> {
314 let instance = GLOBAL_FACTORY.create_instance(name, config)?;
315 let metadata = GLOBAL_FACTORY.get_plugin_metadata(name)?;
316 Ok(PluginWrapper::new(instance, metadata))
317}
318
319pub fn list_global_plugins() -> Vec<PluginMetadata> {
321 GLOBAL_FACTORY.list_plugins()
322}
323
324pub struct LinearKernelPlugin;
326
327impl KernelApproximationPlugin for LinearKernelPlugin {
328 fn metadata(&self) -> PluginMetadata {
329 PluginMetadata {
331 name: "linear_kernel".to_string(),
332 version: "1.0.0".to_string(),
333 description: "Simple linear kernel approximation plugin".to_string(),
334 author: "sklears".to_string(),
335 supported_kernels: vec!["linear".to_string()],
336 required_parameters: vec!["n_components".to_string()],
337 optional_parameters: vec!["normalize".to_string()],
338 }
339 }
340
341 fn create(
342 &self,
343 config: PluginConfig,
344 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
345 let n_components = config
346 .parameters
347 .get("n_components")
348 .and_then(|v| v.as_u64())
349 .ok_or_else(|| PluginError::InvalidConfiguration {
350 message: "n_components parameter required".to_string(),
351 })? as usize;
352
353 let normalize = config
354 .parameters
355 .get("normalize")
356 .and_then(|v| v.as_bool())
357 .unwrap_or(false);
358
359 Ok(Box::new(LinearKernelInstance {
360 n_components,
361 normalize,
362 projection_matrix: None,
363 }))
364 }
365
366 fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError> {
367 if !config.parameters.contains_key("n_components") {
368 return Err(PluginError::InvalidConfiguration {
369 message: "n_components parameter is required".to_string(),
370 });
371 }
372
373 if let Some(n_comp) = config
374 .parameters
375 .get("n_components")
376 .and_then(|v| v.as_u64())
377 {
378 if n_comp == 0 {
379 return Err(PluginError::InvalidConfiguration {
380 message: "n_components must be greater than 0".to_string(),
381 });
382 }
383 } else {
384 return Err(PluginError::InvalidConfiguration {
385 message: "n_components must be a positive integer".to_string(),
386 });
387 }
388
389 Ok(())
390 }
391
392 fn default_config(&self) -> PluginConfig {
393 let mut config = PluginConfig::default();
394 config.parameters.insert(
395 "n_components".to_string(),
396 serde_json::Value::Number(100.into()),
397 );
398 config
399 .parameters
400 .insert("normalize".to_string(), serde_json::Value::Bool(false));
401 config
402 }
403}
404
405pub struct LinearKernelInstance {
407 n_components: usize,
408 normalize: bool,
409 projection_matrix: Option<Array2<f64>>,
410}
411
412impl KernelApproximationInstance for LinearKernelInstance {
413 fn fit(&mut self, x: &Array2<f64>, _y: &()) -> std::result::Result<(), PluginError> {
414 use scirs2_core::random::thread_rng;
415 use scirs2_core::random::{Distribution, StandardNormal};
416
417 let (_, n_features) = x.dim();
418 let mut rng = thread_rng();
419
420 let mut proj_matrix = Array2::zeros((n_features, self.n_components));
422 for elem in proj_matrix.iter_mut() {
423 *elem = rng.sample(StandardNormal);
424 }
425
426 if self.normalize {
427 for j in 0..self.n_components {
429 let mut col = proj_matrix.column_mut(j);
430 let norm = col.mapv(|x: f64| x * x).sum().sqrt();
431 if norm > 1e-8 {
432 col /= norm;
433 }
434 }
435 }
436
437 self.projection_matrix = Some(proj_matrix);
438 Ok(())
439 }
440
441 fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError> {
442 let proj_matrix =
443 self.projection_matrix
444 .as_ref()
445 .ok_or_else(|| PluginError::InitializationFailed {
446 message: "Plugin not fitted".to_string(),
447 })?;
448
449 Ok(x.dot(proj_matrix))
450 }
451
452 fn is_fitted(&self) -> bool {
453 self.projection_matrix.is_some()
454 }
455
456 fn n_output_features(&self) -> Option<usize> {
457 if self.is_fitted() {
458 Some(self.n_components)
459 } else {
460 None
461 }
462 }
463
464 fn clone_instance(&self) -> Box<dyn KernelApproximationInstance> {
465 Box::new(LinearKernelInstance {
466 n_components: self.n_components,
467 normalize: self.normalize,
468 projection_matrix: self.projection_matrix.clone(),
469 })
470 }
471
472 fn as_any(&self) -> &dyn Any {
473 self
474 }
475}
476
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fresh factory with the linear kernel plugin already registered.
    fn factory_with_linear_plugin() -> PluginFactory {
        let factory = PluginFactory::new();
        factory
            .register_plugin(Box::new(LinearKernelPlugin))
            .unwrap();
        factory
    }

    /// Configuration that sets only `n_components`.
    fn config_with_components(n_components: u64) -> PluginConfig {
        let mut config = PluginConfig::default();
        config.parameters.insert(
            "n_components".to_string(),
            serde_json::Value::Number(n_components.into()),
        );
        config
    }

    #[test]
    fn test_plugin_registration() {
        let factory = PluginFactory::new();

        let outcome = factory.register_plugin(Box::new(LinearKernelPlugin));
        assert!(outcome.is_ok());

        let listed = factory.list_plugins();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].name, "linear_kernel");
    }

    #[test]
    fn test_plugin_instance_creation() {
        let factory = factory_with_linear_plugin();

        let created = factory.create_instance("linear_kernel", config_with_components(50));
        assert!(created.is_ok());
    }

    #[test]
    fn test_plugin_wrapper_fit_transform() {
        let factory = factory_with_linear_plugin();

        let instance = factory
            .create_instance("linear_kernel", config_with_components(30))
            .unwrap();
        let metadata = factory.get_plugin_metadata("linear_kernel").unwrap();
        let wrapper = PluginWrapper::new(instance, metadata);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = wrapper.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        // Three input rows, thirty requested components.
        assert_eq!(transformed.shape(), &[3, 30]);
    }

    #[test]
    fn test_global_plugin_registry() {
        let outcome = register_global_plugin(Box::new(LinearKernelPlugin));
        assert!(outcome.is_ok());

        let listed = list_global_plugins();
        assert!(!listed.is_empty());
    }

    #[test]
    fn test_invalid_configuration() {
        let factory = factory_with_linear_plugin();

        // The default configuration lacks the required `n_components` entry.
        let empty_config = PluginConfig::default();
        let result = factory.create_instance("linear_kernel", empty_config);
        assert!(result.is_err());
    }
}