Skip to main content

sklears_compose/plugin_architecture/examplescaler_traits.rs

//! # ExampleScaler - Trait Implementations
//!
//! This module contains trait implementations for `ExampleScaler`.
//!
//! ## Implemented Traits
//!
//! - `PluginComponent`
//! - `PluginTransformer`
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
11
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use sklears_core::{
14    error::{Result as SklResult, SklearsError},
15    traits::Estimator,
16    types::Float,
17};
18use std::any::Any;
19
20use super::functions::{PluginComponent, PluginTransformer};
21use super::types::{ComponentConfig, ComponentContext, ExampleScaler};
22
23impl PluginComponent for ExampleScaler {
24    fn component_type(&self) -> &'static str {
25        "example_scaler"
26    }
27    fn config(&self) -> &ComponentConfig {
28        &self.config
29    }
30    fn initialize(&mut self, _context: &ComponentContext) -> SklResult<()> {
31        Ok(())
32    }
33    fn clone_component(&self) -> Box<dyn PluginComponent> {
34        Box::new(self.clone())
35    }
36    fn as_any(&self) -> &dyn Any {
37        self
38    }
39    fn as_any_mut(&mut self) -> &mut dyn Any {
40        self
41    }
42}
43
44impl PluginTransformer for ExampleScaler {
45    fn fit(
46        &mut self,
47        _x: &ArrayView2<'_, Float>,
48        _y: Option<&ArrayView1<'_, Float>>,
49    ) -> SklResult<()> {
50        self.fitted = true;
51        Ok(())
52    }
53    fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
54        if !self.fitted {
55            return Err(SklearsError::InvalidOperation(
56                "Transformer must be fitted before transform".to_string(),
57            ));
58        }
59        Ok(x.mapv(|v| v * self.scale_factor))
60    }
61    fn is_fitted(&self) -> bool {
62        self.fitted
63    }
64    fn get_feature_names_out(&self, input_features: Option<&[String]>) -> Vec<String> {
65        match input_features {
66            Some(features) => features.iter().map(|f| format!("scaled_{f}")).collect(),
67            None => (0..10).map(|i| format!("scaled_feature_{i}")).collect(),
68        }
69    }
70}