sklears_kernel_approximation/
plugin_architecture.rs1use scirs2_core::ndarray::Array2;
8use serde::{Deserialize, Serialize};
9use sklears_core::error::SklearsError;
10use sklears_core::traits::{Fit, Transform};
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use thiserror::Error;
15
16#[derive(Error, Debug)]
18pub enum PluginError {
20 #[error("Plugin not found: {name}")]
21 PluginNotFound { name: String },
22 #[error("Plugin already registered: {name}")]
23 PluginAlreadyRegistered { name: String },
24 #[error("Invalid plugin configuration: {message}")]
25 InvalidConfiguration { message: String },
26 #[error("Plugin initialization failed: {message}")]
27 InitializationFailed { message: String },
28 #[error("Type casting error for plugin: {name}")]
29 TypeCastError { name: String },
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct PluginMetadata {
36 pub name: String,
38 pub version: String,
40 pub description: String,
42 pub author: String,
44 pub supported_kernels: Vec<String>,
46 pub required_parameters: Vec<String>,
48 pub optional_parameters: Vec<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54#[derive(Default)]
56pub struct PluginConfig {
57 pub parameters: HashMap<String, serde_json::Value>,
59 pub random_state: Option<u64>,
61}
62
63pub trait KernelApproximationPlugin: Send + Sync {
65 fn metadata(&self) -> PluginMetadata;
67
68 fn create(
70 &self,
71 config: PluginConfig,
72 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError>;
73
74 fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError>;
76
77 fn default_config(&self) -> PluginConfig;
79}
80
81pub trait KernelApproximationInstance: Send + Sync {
83 fn fit(&mut self, x: &Array2<f64>, y: &()) -> std::result::Result<(), PluginError>;
85
86 fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError>;
88
89 fn is_fitted(&self) -> bool;
91
92 fn n_output_features(&self) -> Option<usize>;
94
95 fn clone_instance(&self) -> Box<dyn KernelApproximationInstance>;
97
98 fn as_any(&self) -> &dyn Any;
100}
101
102pub struct PluginFactory {
104 plugins: Arc<RwLock<HashMap<String, Box<dyn KernelApproximationPlugin>>>>,
105}
106
107impl Default for PluginFactory {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl PluginFactory {
114 pub fn new() -> Self {
116 Self {
117 plugins: Arc::new(RwLock::new(HashMap::new())),
118 }
119 }
120
121 pub fn register_plugin(
123 &self,
124 plugin: Box<dyn KernelApproximationPlugin>,
125 ) -> std::result::Result<(), PluginError> {
126 let metadata = plugin.metadata();
127 let mut plugins = self.plugins.write().unwrap();
128
129 if plugins.contains_key(&metadata.name) {
130 return Err(PluginError::PluginAlreadyRegistered {
131 name: metadata.name,
132 });
133 }
134
135 plugins.insert(metadata.name.clone(), plugin);
136 Ok(())
137 }
138
139 pub fn unregister_plugin(&self, name: &str) -> std::result::Result<(), PluginError> {
141 let mut plugins = self.plugins.write().unwrap();
142 plugins
143 .remove(name)
144 .ok_or_else(|| PluginError::PluginNotFound {
145 name: name.to_string(),
146 })?;
147 Ok(())
148 }
149
150 pub fn get_plugin_metadata(
152 &self,
153 name: &str,
154 ) -> std::result::Result<PluginMetadata, PluginError> {
155 let plugins = self.plugins.read().unwrap();
156 let plugin = plugins
157 .get(name)
158 .ok_or_else(|| PluginError::PluginNotFound {
159 name: name.to_string(),
160 })?;
161 Ok(plugin.metadata())
162 }
163
164 pub fn list_plugins(&self) -> Vec<PluginMetadata> {
166 let plugins = self.plugins.read().unwrap();
167 plugins.values().map(|p| p.metadata()).collect()
168 }
169
170 pub fn create_instance(
172 &self,
173 name: &str,
174 config: PluginConfig,
175 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
176 let plugins = self.plugins.read().unwrap();
177 let plugin = plugins
178 .get(name)
179 .ok_or_else(|| PluginError::PluginNotFound {
180 name: name.to_string(),
181 })?;
182
183 plugin.validate_config(&config)?;
184 plugin.create(config)
185 }
186
187 pub fn get_default_config(&self, name: &str) -> std::result::Result<PluginConfig, PluginError> {
189 let plugins = self.plugins.read().unwrap();
190 let plugin = plugins
191 .get(name)
192 .ok_or_else(|| PluginError::PluginNotFound {
193 name: name.to_string(),
194 })?;
195 Ok(plugin.default_config())
196 }
197}
198
199pub struct PluginWrapper {
201 instance: Box<dyn KernelApproximationInstance>,
202 metadata: PluginMetadata,
203}
204
205impl PluginWrapper {
206 pub fn new(instance: Box<dyn KernelApproximationInstance>, metadata: PluginMetadata) -> Self {
208 Self { instance, metadata }
209 }
210
211 pub fn metadata(&self) -> &PluginMetadata {
213 &self.metadata
214 }
215
216 pub fn instance(&self) -> &dyn KernelApproximationInstance {
218 self.instance.as_ref()
219 }
220
221 pub fn instance_mut(&mut self) -> &mut dyn KernelApproximationInstance {
223 self.instance.as_mut()
224 }
225}
226
227impl Clone for PluginWrapper {
228 fn clone(&self) -> Self {
229 Self {
230 instance: self.instance.clone_instance(),
231 metadata: self.metadata.clone(),
232 }
233 }
234}
235
236impl Fit<Array2<f64>, ()> for PluginWrapper {
237 type Fitted = FittedPluginWrapper;
238
239 fn fit(mut self, x: &Array2<f64>, y: &()) -> Result<Self::Fitted, SklearsError> {
240 self.instance
241 .fit(x, y)
242 .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))?;
243 Ok(FittedPluginWrapper {
244 instance: self.instance,
245 metadata: self.metadata,
246 })
247 }
248}
249
250pub struct FittedPluginWrapper {
252 instance: Box<dyn KernelApproximationInstance>,
253 metadata: PluginMetadata,
254}
255
256impl FittedPluginWrapper {
257 pub fn metadata(&self) -> &PluginMetadata {
259 &self.metadata
260 }
261
262 pub fn instance(&self) -> &dyn KernelApproximationInstance {
264 self.instance.as_ref()
265 }
266
267 pub fn n_output_features(&self) -> Option<usize> {
269 self.instance.n_output_features()
270 }
271}
272
273impl Clone for FittedPluginWrapper {
274 fn clone(&self) -> Self {
275 Self {
276 instance: self.instance.clone_instance(),
277 metadata: self.metadata.clone(),
278 }
279 }
280}
281
282impl Transform<Array2<f64>, Array2<f64>> for FittedPluginWrapper {
283 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
284 self.instance
285 .transform(x)
286 .map_err(|e| SklearsError::InvalidInput(format!("{}", e)))
287 }
288}
289
290static GLOBAL_FACTORY: std::sync::LazyLock<PluginFactory> =
292 std::sync::LazyLock::new(PluginFactory::new);
293
294pub fn register_global_plugin(
296 plugin: Box<dyn KernelApproximationPlugin>,
297) -> std::result::Result<(), PluginError> {
298 GLOBAL_FACTORY.register_plugin(plugin)
299}
300
301pub fn create_global_plugin_instance(
303 name: &str,
304 config: PluginConfig,
305) -> std::result::Result<PluginWrapper, PluginError> {
306 let instance = GLOBAL_FACTORY.create_instance(name, config)?;
307 let metadata = GLOBAL_FACTORY.get_plugin_metadata(name)?;
308 Ok(PluginWrapper::new(instance, metadata))
309}
310
311pub fn list_global_plugins() -> Vec<PluginMetadata> {
313 GLOBAL_FACTORY.list_plugins()
314}
315
/// Built-in plugin, registered as `"linear_kernel"`, that produces random
/// linear projections ([`LinearKernelInstance`]).
pub struct LinearKernelPlugin;
318
319impl KernelApproximationPlugin for LinearKernelPlugin {
320 fn metadata(&self) -> PluginMetadata {
321 PluginMetadata {
322 name: "linear_kernel".to_string(),
323 version: "1.0.0".to_string(),
324 description: "Simple linear kernel approximation plugin".to_string(),
325 author: "sklears".to_string(),
326 supported_kernels: vec!["linear".to_string()],
327 required_parameters: vec!["n_components".to_string()],
328 optional_parameters: vec!["normalize".to_string()],
329 }
330 }
331
332 fn create(
333 &self,
334 config: PluginConfig,
335 ) -> std::result::Result<Box<dyn KernelApproximationInstance>, PluginError> {
336 let n_components = config
337 .parameters
338 .get("n_components")
339 .and_then(|v| v.as_u64())
340 .ok_or_else(|| PluginError::InvalidConfiguration {
341 message: "n_components parameter required".to_string(),
342 })? as usize;
343
344 let normalize = config
345 .parameters
346 .get("normalize")
347 .and_then(|v| v.as_bool())
348 .unwrap_or(false);
349
350 Ok(Box::new(LinearKernelInstance {
351 n_components,
352 normalize,
353 projection_matrix: None,
354 }))
355 }
356
357 fn validate_config(&self, config: &PluginConfig) -> std::result::Result<(), PluginError> {
358 if !config.parameters.contains_key("n_components") {
359 return Err(PluginError::InvalidConfiguration {
360 message: "n_components parameter is required".to_string(),
361 });
362 }
363
364 if let Some(n_comp) = config
365 .parameters
366 .get("n_components")
367 .and_then(|v| v.as_u64())
368 {
369 if n_comp == 0 {
370 return Err(PluginError::InvalidConfiguration {
371 message: "n_components must be greater than 0".to_string(),
372 });
373 }
374 } else {
375 return Err(PluginError::InvalidConfiguration {
376 message: "n_components must be a positive integer".to_string(),
377 });
378 }
379
380 Ok(())
381 }
382
383 fn default_config(&self) -> PluginConfig {
384 let mut config = PluginConfig::default();
385 config.parameters.insert(
386 "n_components".to_string(),
387 serde_json::Value::Number(100.into()),
388 );
389 config
390 .parameters
391 .insert("normalize".to_string(), serde_json::Value::Bool(false));
392 config
393 }
394}
395
396pub struct LinearKernelInstance {
398 n_components: usize,
399 normalize: bool,
400 projection_matrix: Option<Array2<f64>>,
401}
402
403impl KernelApproximationInstance for LinearKernelInstance {
404 fn fit(&mut self, x: &Array2<f64>, _y: &()) -> std::result::Result<(), PluginError> {
405 use scirs2_core::random::thread_rng;
406 use scirs2_core::random::StandardNormal;
407
408 let (_, n_features) = x.dim();
409 let mut rng = thread_rng();
410
411 let mut proj_matrix = Array2::zeros((n_features, self.n_components));
413 for elem in proj_matrix.iter_mut() {
414 *elem = rng.sample(StandardNormal);
415 }
416
417 if self.normalize {
418 for j in 0..self.n_components {
420 let mut col = proj_matrix.column_mut(j);
421 let norm = col.mapv(|x: f64| x * x).sum().sqrt();
422 if norm > 1e-8 {
423 col /= norm;
424 }
425 }
426 }
427
428 self.projection_matrix = Some(proj_matrix);
429 Ok(())
430 }
431
432 fn transform(&self, x: &Array2<f64>) -> std::result::Result<Array2<f64>, PluginError> {
433 let proj_matrix =
434 self.projection_matrix
435 .as_ref()
436 .ok_or_else(|| PluginError::InitializationFailed {
437 message: "Plugin not fitted".to_string(),
438 })?;
439
440 Ok(x.dot(proj_matrix))
441 }
442
443 fn is_fitted(&self) -> bool {
444 self.projection_matrix.is_some()
445 }
446
447 fn n_output_features(&self) -> Option<usize> {
448 if self.is_fitted() {
449 Some(self.n_components)
450 } else {
451 None
452 }
453 }
454
455 fn clone_instance(&self) -> Box<dyn KernelApproximationInstance> {
456 Box::new(LinearKernelInstance {
457 n_components: self.n_components,
458 normalize: self.normalize,
459 projection_matrix: self.projection_matrix.clone(),
460 })
461 }
462
463 fn as_any(&self) -> &dyn Any {
464 self
465 }
466}
467
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builds a fresh factory with [`LinearKernelPlugin`] registered.
    fn factory_with_linear_plugin() -> PluginFactory {
        let factory = PluginFactory::new();
        factory
            .register_plugin(Box::new(LinearKernelPlugin))
            .expect("registering into an empty factory cannot collide");
        factory
    }

    /// Builds a config that sets only `n_components`.
    fn config_with_components(n_components: u64) -> PluginConfig {
        let mut config = PluginConfig::default();
        config.parameters.insert(
            "n_components".to_string(),
            serde_json::Value::Number(n_components.into()),
        );
        config
    }

    #[test]
    fn test_plugin_registration() {
        let factory = PluginFactory::new();
        let plugin = Box::new(LinearKernelPlugin);

        assert!(factory.register_plugin(plugin).is_ok());

        let plugins = factory.list_plugins();
        assert_eq!(plugins.len(), 1);
        assert_eq!(plugins[0].name, "linear_kernel");
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let factory = factory_with_linear_plugin();

        let result = factory.register_plugin(Box::new(LinearKernelPlugin));
        assert!(matches!(
            &result,
            Err(PluginError::PluginAlreadyRegistered { name }) if name == "linear_kernel"
        ));
        assert_eq!(factory.list_plugins().len(), 1);
    }

    #[test]
    fn test_unregister_plugin() {
        let factory = factory_with_linear_plugin();

        assert!(factory.unregister_plugin("linear_kernel").is_ok());
        assert!(factory.list_plugins().is_empty());
        assert!(matches!(
            factory.unregister_plugin("linear_kernel"),
            Err(PluginError::PluginNotFound { .. })
        ));
    }

    #[test]
    fn test_unknown_plugin_lookups_fail() {
        let factory = PluginFactory::new();

        assert!(matches!(
            factory.get_plugin_metadata("missing"),
            Err(PluginError::PluginNotFound { .. })
        ));
        assert!(matches!(
            factory.get_default_config("missing"),
            Err(PluginError::PluginNotFound { .. })
        ));
        assert!(matches!(
            factory.create_instance("missing", PluginConfig::default()),
            Err(PluginError::PluginNotFound { .. })
        ));
    }

    #[test]
    fn test_default_config_is_valid() {
        let factory = factory_with_linear_plugin();

        let config = factory.get_default_config("linear_kernel").unwrap();
        assert_eq!(
            config.parameters.get("n_components").and_then(|v| v.as_u64()),
            Some(100)
        );
        assert!(factory.create_instance("linear_kernel", config).is_ok());
    }

    #[test]
    fn test_plugin_instance_creation() {
        let factory = factory_with_linear_plugin();

        let instance = factory.create_instance("linear_kernel", config_with_components(50));
        assert!(instance.is_ok());
        assert!(!instance.unwrap().is_fitted());
    }

    #[test]
    fn test_plugin_wrapper_fit_transform() {
        let factory = factory_with_linear_plugin();

        let instance = factory
            .create_instance("linear_kernel", config_with_components(30))
            .unwrap();
        let metadata = factory.get_plugin_metadata("linear_kernel").unwrap();
        let wrapper = PluginWrapper::new(instance, metadata);

        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = wrapper.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 30]);
        assert_eq!(fitted.n_output_features(), Some(30));
    }

    #[test]
    fn test_global_plugin_registry() {
        // The global registry is shared by every test in the process, so an
        // earlier registration of the same plugin is acceptable here.
        match register_global_plugin(Box::new(LinearKernelPlugin)) {
            Ok(()) | Err(PluginError::PluginAlreadyRegistered { .. }) => {}
            Err(other) => panic!("unexpected registration error: {other}"),
        }

        let plugins = list_global_plugins();
        assert!(plugins.iter().any(|meta| meta.name == "linear_kernel"));

        let wrapper =
            create_global_plugin_instance("linear_kernel", config_with_components(5)).unwrap();
        assert_eq!(wrapper.metadata().name, "linear_kernel");
        assert!(!wrapper.instance().is_fitted());
    }

    #[test]
    fn test_invalid_configuration() {
        let factory = factory_with_linear_plugin();

        let missing = factory.create_instance("linear_kernel", PluginConfig::default());
        assert!(matches!(
            missing,
            Err(PluginError::InvalidConfiguration { .. })
        ));

        let zero = factory.create_instance("linear_kernel", config_with_components(0));
        assert!(matches!(zero, Err(PluginError::InvalidConfiguration { .. })));

        let mut not_integer = PluginConfig::default();
        not_integer.parameters.insert(
            "n_components".to_string(),
            serde_json::Value::String("ten".to_string()),
        );
        assert!(matches!(
            factory.create_instance("linear_kernel", not_integer),
            Err(PluginError::InvalidConfiguration { .. })
        ));
    }
}