sklears_core/plugin/core_traits.rs
//! Core Plugin Traits
//!
//! This module defines the fundamental traits that all plugins must implement:
//! the base `Plugin` trait plus the algorithm, transformer and clustering
//! specializations that build on it in the sklears plugin system.
6
7use crate::error::Result;
8use crate::traits::{Predict, Transform};
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::fmt::Debug;
12
13// Re-export types that will be used by traits
14use super::types_config::{PluginConfig, PluginMetadata, PluginParameter};
15
16/// Core trait that all plugins must implement
17///
18/// This trait defines the fundamental interface that every plugin in the sklears
19/// ecosystem must provide. It ensures consistency across all plugin types and
20/// enables the plugin system to manage plugins in a type-safe manner.
21///
22/// # Examples
23///
24/// ```rust,ignore
25/// use sklears_core::plugin::{Plugin, PluginMetadata, PluginConfig};
26/// use sklears_core::error::Result;
27/// use std::any::{Any, TypeId};
28///
29/// #[derive(Debug)]
30/// struct MyPlugin {
31/// name: String,
32/// }
33///
34/// impl Plugin for MyPlugin {
35/// fn id(&self) -> &str {
36/// &self.name
37/// }
38///
39/// fn metadata(&self) -> PluginMetadata {
40/// PluginMetadata::default()
41/// }
42///
43/// fn initialize(&mut self, _config: &PluginConfig) -> Result<()> {
44/// Ok(())
45/// }
46///
47/// fn is_compatible(&self, _input_type: TypeId) -> bool {
48/// true
49/// }
50///
51/// fn as_any(&self) -> &dyn Any {
52/// self
53/// }
54///
55/// fn as_any_mut(&mut self) -> &mut dyn Any {
56/// self
57/// }
58///
59/// fn validate_config(&self, _config: &PluginConfig) -> Result<()> {
60/// Ok(())
61/// }
62///
63/// fn cleanup(&mut self) -> Result<()> {
64/// Ok(())
65/// }
66/// }
67/// ```
68pub trait Plugin: Send + Sync + Debug {
69 /// Unique identifier for the plugin
70 ///
71 /// This should be a unique string that identifies the plugin within
72 /// the system. It's used for plugin discovery and registration.
73 fn id(&self) -> &str;
74
75 /// Plugin metadata
76 ///
77 /// Returns comprehensive metadata about the plugin including name,
78 /// version, description, capabilities, and dependencies.
79 fn metadata(&self) -> PluginMetadata;
80
81 /// Initialize the plugin with configuration
82 ///
83 /// This method is called when the plugin is loaded and provides
84 /// the plugin with its configuration. The plugin should perform
85 /// any necessary initialization here.
86 fn initialize(&mut self, config: &PluginConfig) -> Result<()>;
87
88 /// Check if the plugin is compatible with the given input type
89 ///
90 /// This method allows the plugin system to determine if a plugin
91 /// can handle a particular data type before attempting to use it.
92 fn is_compatible(&self, input_type: TypeId) -> bool;
93
94 /// Get the plugin as Any for downcasting
95 ///
96 /// This enables type-safe downcasting to the concrete plugin type
97 /// when needed for specialized operations.
98 fn as_any(&self) -> &dyn Any;
99
100 /// Get the plugin as mutable Any for downcasting
101 ///
102 /// This enables mutable access to the concrete plugin type
103 /// when needed for specialized operations.
104 fn as_any_mut(&mut self) -> &mut dyn Any;
105
106 /// Validate plugin configuration
107 ///
108 /// This method should check if the provided configuration is valid
109 /// for this plugin. It's called before initialization to catch
110 /// configuration errors early.
111 fn validate_config(&self, config: &PluginConfig) -> Result<()>;
112
113 /// Cleanup resources when plugin is unloaded
114 ///
115 /// This method is called when the plugin is being unloaded and
116 /// should clean up any resources the plugin has allocated.
117 fn cleanup(&mut self) -> Result<()>;
118}
119
120/// Trait for algorithm plugins that can fit and predict
121///
122/// This trait defines the interface for machine learning algorithms that
123/// follow the fit/predict pattern. It extends the base Plugin trait with
124/// algorithm-specific functionality.
125///
126/// # Type Parameters
127///
128/// * `X` - The input feature type (e.g., `Array2<f64>`)
129/// * `Y` - The target/label type (e.g., `Array1<f64>`)
130/// * `Output` - The prediction output type
131///
132/// # Examples
133///
134/// ```rust,ignore
135/// use sklears_core::plugin::{Plugin, AlgorithmPlugin, PluginParameter};
136/// use sklears_core::traits::Predict;
137/// use sklears_core::error::Result;
138/// use std::collections::HashMap;
139///
140/// #[derive(Debug)]
141/// struct LinearRegression {
142/// // algorithm parameters
143/// }
144///
145/// struct FittedLinearRegression {
146/// // fitted model state
147/// }
148///
149/// impl Predict<Vec<f64>, Vec<f64>> for FittedLinearRegression {
150/// fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
151/// // prediction implementation
152/// Ok(x.clone())
153/// }
154/// }
155///
156/// impl AlgorithmPlugin<Vec<f64>, Vec<f64>, Vec<f64>> for LinearRegression {
157/// type Fitted = FittedLinearRegression;
158///
159/// fn fit(&self, x: &Vec<f64>, y: &Vec<f64>) -> Result<Self::Fitted> {
160/// Ok(FittedLinearRegression {})
161/// }
162///
163/// fn predict(&self, fitted: &Self::Fitted, x: &Vec<f64>) -> Result<Vec<f64>> {
164/// fitted.predict(x)
165/// }
166///
167/// fn get_parameters(&self) -> HashMap<String, PluginParameter> {
168/// HashMap::new()
169/// }
170///
171/// fn set_parameters(&mut self, _params: HashMap<String, PluginParameter>) -> Result<()> {
172/// Ok(())
173/// }
174/// }
175/// ```
176pub trait AlgorithmPlugin<X, Y, Output>: Plugin + Send + Sync {
177 /// The fitted model type
178 ///
179 /// This type represents the state of the algorithm after training.
180 /// It must implement the Predict trait to enable predictions.
181 type Fitted: Predict<X, Output> + Send + Sync;
182
183 /// Fit the algorithm to training data
184 ///
185 /// This method trains the algorithm on the provided data and returns
186 /// a fitted model that can be used for predictions.
187 ///
188 /// # Arguments
189 ///
190 /// * `x` - Training features
191 /// * `y` - Training targets/labels
192 ///
193 /// # Returns
194 ///
195 /// A fitted model instance or an error if training fails.
196 fn fit(&self, x: &X, y: &Y) -> Result<Self::Fitted>;
197
198 /// Make predictions using the fitted model
199 ///
200 /// This method uses the fitted model to make predictions on new data.
201 ///
202 /// # Arguments
203 ///
204 /// * `fitted` - The fitted model from the fit method
205 /// * `x` - Input features for prediction
206 ///
207 /// # Returns
208 ///
209 /// Predictions or an error if prediction fails.
210 fn predict(&self, fitted: &Self::Fitted, x: &X) -> Result<Output>;
211
212 /// Get algorithm-specific parameters
213 ///
214 /// Returns a map of all configurable parameters for this algorithm.
215 /// This enables introspection and parameter tuning.
216 fn get_parameters(&self) -> HashMap<String, PluginParameter>;
217
218 /// Set algorithm-specific parameters
219 ///
220 /// Allows updating the algorithm's parameters. The algorithm should
221 /// validate that the provided parameters are valid.
222 ///
223 /// # Arguments
224 ///
225 /// * `params` - Map of parameter names to values
226 ///
227 /// # Returns
228 ///
229 /// Ok(()) if parameters were set successfully, or an error if
230 /// any parameter is invalid.
231 fn set_parameters(&mut self, params: HashMap<String, PluginParameter>) -> Result<()>;
232}
233
234/// Trait for transformer plugins
235///
236/// This trait defines the interface for data transformation algorithms
237/// that can fit to data and then transform new data using the learned
238/// transformation.
239///
240/// # Type Parameters
241///
242/// * `X` - The input data type
243/// * `Output` - The transformed output type (defaults to X)
244///
245/// # Examples
246///
247/// ```rust,ignore
248/// use sklears_core::plugin::{Plugin, TransformerPlugin};
249/// use sklears_core::traits::Transform;
250/// use sklears_core::error::Result;
251///
252/// #[derive(Debug)]
253/// struct StandardScaler {
254/// // scaler parameters
255/// }
256///
257/// struct FittedStandardScaler {
258/// // fitted scaler state
259/// }
260///
261/// impl Transform<Vec<f64>, Vec<f64>> for FittedStandardScaler {
262/// fn transform(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
263/// // transformation implementation
264/// Ok(x.clone())
265/// }
266/// }
267///
268/// impl TransformerPlugin<Vec<f64>, Vec<f64>> for StandardScaler {
269/// type Fitted = FittedStandardScaler;
270///
271/// fn fit_transform(&self, x: &Vec<f64>) -> Result<(Self::Fitted, Vec<f64>)> {
272/// let fitted = FittedStandardScaler {};
273/// let transformed = fitted.transform(x)?;
274/// Ok((fitted, transformed))
275/// }
276///
277/// fn transform(&self, fitted: &Self::Fitted, x: &Vec<f64>) -> Result<Vec<f64>> {
278/// fitted.transform(x)
279/// }
280/// }
281/// ```
282pub trait TransformerPlugin<X, Output = X>: Plugin + Send + Sync {
283 /// The fitted transformer type
284 ///
285 /// This type represents the state of the transformer after fitting.
286 /// It must implement the Transform trait to enable transformations.
287 type Fitted: Transform<X, Output> + Send + Sync;
288
289 /// Fit the transformer to data and return the fitted transformer along with transformed data
290 ///
291 /// This method fits the transformer to the input data and immediately
292 /// applies the transformation, returning both the fitted transformer
293 /// and the transformed data.
294 ///
295 /// # Arguments
296 ///
297 /// * `x` - Input data to fit and transform
298 ///
299 /// # Returns
300 ///
301 /// A tuple of (fitted transformer, transformed data) or an error.
302 fn fit_transform(&self, x: &X) -> Result<(Self::Fitted, Output)>;
303
304 /// Transform data using the fitted transformer
305 ///
306 /// This method applies the previously fitted transformation to new data.
307 ///
308 /// # Arguments
309 ///
310 /// * `fitted` - The fitted transformer from fit_transform
311 /// * `x` - Input data to transform
312 ///
313 /// # Returns
314 ///
315 /// Transformed data or an error if transformation fails.
316 fn transform(&self, fitted: &Self::Fitted, x: &X) -> Result<Output>;
317}
318
319/// Trait for clustering plugins
320///
321/// This trait defines the interface for clustering algorithms that
322/// can group data points into clusters and provide cluster information.
323///
324/// # Type Parameters
325///
326/// * `X` - The input data type
327///
328/// # Examples
329///
330/// ```rust,ignore
331/// use sklears_core::plugin::{Plugin, ClusteringPlugin};
332/// use sklears_core::error::Result;
333/// use std::collections::HashMap;
334///
335/// #[derive(Debug)]
336/// struct KMeans {
337/// n_clusters: usize,
338/// }
339///
340/// impl ClusteringPlugin<Vec<Vec<f64>>> for KMeans {
341/// type Labels = Vec<usize>;
342///
343/// fn fit_predict(&self, x: &Vec<Vec<f64>>) -> Result<Self::Labels> {
344/// // clustering implementation
345/// Ok(vec![0; x.len()])
346/// }
347///
348/// fn cluster_centers(&self) -> Option<Vec<Vec<f64>>> {
349/// // return cluster centers if available
350/// None
351/// }
352///
353/// fn cluster_stats(&self) -> HashMap<String, f64> {
354/// // return clustering statistics
355/// HashMap::new()
356/// }
357/// }
358/// ```
359pub trait ClusteringPlugin<X>: Plugin + Send + Sync {
360 /// The cluster labels type
361 ///
362 /// This type represents the cluster assignments for each data point.
363 /// Typically this would be `Vec<usize>` for integer cluster labels.
364 type Labels;
365
366 /// Fit the clustering algorithm and return cluster assignments
367 ///
368 /// This method performs clustering on the input data and returns
369 /// the cluster assignments for each data point.
370 ///
371 /// # Arguments
372 ///
373 /// * `x` - Input data to cluster
374 ///
375 /// # Returns
376 ///
377 /// Cluster labels for each data point or an error.
378 fn fit_predict(&self, x: &X) -> Result<Self::Labels>;
379
380 /// Get cluster centers (if applicable)
381 ///
382 /// For algorithms that compute explicit cluster centers (like K-means),
383 /// this method returns the center points. Returns None if the algorithm
384 /// doesn't compute explicit centers.
385 fn cluster_centers(&self) -> Option<X>;
386
387 /// Get cluster statistics
388 ///
389 /// Returns various statistics about the clustering result, such as
390 /// inertia, silhouette score, number of clusters, etc.
391 fn cluster_stats(&self) -> HashMap<String, f64>;
392}