scouter_client/drifter/
spc.rs

1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::Axis;
3use ndarray::{concatenate, Array2};
4use num_traits::{Float, FromPrimitive, Num};
5use numpy::PyReadonlyArray2;
6use scouter_drift::error::DriftError;
7use scouter_drift::{
8    spc::{generate_alerts, SpcDriftMap, SpcMonitor},
9    CategoricalFeatureHelpers,
10};
11use scouter_types::{
12    create_feature_map,
13    spc::{SpcAlertRule, SpcDriftConfig, SpcDriftProfile, SpcFeatureAlerts},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18#[derive(Default)]
19pub struct SpcDrifter {
20    monitor: SpcMonitor,
21}
22
23impl SpcDrifter {
24    pub fn new() -> Self {
25        let monitor = SpcMonitor::new();
26        SpcDrifter { monitor }
27    }
28
29    pub fn convert_strings_to_numpy_f32(
30        &mut self,
31        features: Vec<String>,
32        array: Vec<Vec<String>>,
33        drift_profile: SpcDriftProfile,
34    ) -> Result<Array2<f32>, DriftError> {
35        let array = self.monitor.convert_strings_to_ndarray_f32(
36            &features,
37            &array,
38            &drift_profile.config.feature_map,
39        )?;
40
41        Ok(array)
42    }
43
44    pub fn convert_strings_to_numpy_f64(
45        &mut self,
46        features: Vec<String>,
47        array: Vec<Vec<String>>,
48        drift_profile: SpcDriftProfile,
49    ) -> Result<Array2<f64>, DriftError> {
50        let array = self.monitor.convert_strings_to_ndarray_f64(
51            &features,
52            &array,
53            &drift_profile.config.feature_map,
54        )?;
55
56        Ok(array)
57    }
58
59    pub fn generate_alerts(
60        &mut self,
61        drift_array: PyReadonlyArray2<f64>,
62        features: Vec<String>,
63        alert_rule: SpcAlertRule,
64    ) -> Result<SpcFeatureAlerts, DriftError> {
65        let drift_array = drift_array.as_array();
66
67        generate_alerts(&drift_array, &features, &alert_rule)
68    }
69
70    pub fn create_string_drift_profile(
71        &mut self,
72        array: Vec<Vec<String>>,
73        features: Vec<String>,
74        mut drift_config: SpcDriftConfig,
75    ) -> Result<SpcDriftProfile, DriftError> {
76        let feature_map = match create_feature_map(&features, &array) {
77            Ok(feature_map) => feature_map,
78            Err(e) => {
79                return Err(e.into());
80            }
81        };
82
83        drift_config.update_feature_map(feature_map.clone());
84
85        let array = self
86            .monitor
87            .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
88
89        self.monitor
90            .create_2d_drift_profile(&features, &array.view(), &drift_config)
91    }
92
93    pub fn create_numeric_drift_profile<F>(
94        &mut self,
95        array: PyReadonlyArray2<F>,
96        features: Vec<String>,
97        drift_config: SpcDriftConfig,
98    ) -> Result<SpcDriftProfile, DriftError>
99    where
100        F: Float
101            + Sync
102            + FromPrimitive
103            + Send
104            + Num
105            + Debug
106            + num_traits::Zero
107            + ndarray::ScalarOperand
108            + numpy::Element,
109        F: Into<f64>,
110    {
111        let array = array.as_array();
112
113        let profile = self
114            .monitor
115            .create_2d_drift_profile(&features, &array, &drift_config)?;
116
117        Ok(profile)
118    }
119
120    pub fn compute_drift(
121        &mut self,
122        data: ConvertedData<'_>,
123        drift_profile: SpcDriftProfile,
124    ) -> Result<SpcDriftMap, DriftError> {
125        let (num_features, num_array, dtype, string_features, string_array) = data;
126        let dtype = dtype.unwrap_or("float32".to_string());
127
128        let mut features = num_features.clone();
129        features.extend(string_features.clone());
130
131        if let Some(string_array) = string_array {
132            if dtype == "float64" {
133                let string_array = self.convert_strings_to_numpy_f64(
134                    string_features,
135                    string_array,
136                    drift_profile.clone(),
137                )?;
138
139                if num_array.is_some() {
140                    let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
141                    let concatenated =
142                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
143                    Ok(self.monitor.compute_drift(
144                        &features,
145                        &concatenated.view(),
146                        &drift_profile,
147                    )?)
148                } else {
149                    Ok(self.monitor.compute_drift(
150                        &features,
151                        &string_array.view(),
152                        &drift_profile,
153                    )?)
154                }
155            } else {
156                let string_array = self.convert_strings_to_numpy_f32(
157                    string_features,
158                    string_array,
159                    drift_profile.clone(),
160                )?;
161
162                if num_array.is_some() {
163                    let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
164                    let concatenated =
165                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
166                    Ok(self.monitor.compute_drift(
167                        &features,
168                        &concatenated.view(),
169                        &drift_profile,
170                    )?)
171                } else {
172                    Ok(self.monitor.compute_drift(
173                        &features,
174                        &string_array.view(),
175                        &drift_profile,
176                    )?)
177                }
178            }
179        } else if dtype == "float64" {
180            let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
181            Ok(self
182                .monitor
183                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
184        } else {
185            let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
186            Ok(self
187                .monitor
188                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
189        }
190    }
191
192    pub fn create_drift_profile(
193        &mut self,
194        data: ConvertedData<'_>,
195        config: SpcDriftConfig,
196    ) -> Result<SpcDriftProfile, DriftError> {
197        let (num_features, num_array, dtype, string_features, string_array) = data;
198
199        let mut features = HashMap::new();
200        let cfg = config.clone();
201        let mut final_config = cfg.clone();
202
203        if let Some(string_array) = string_array {
204            let profile = self.create_string_drift_profile(
205                string_array,
206                string_features,
207                final_config.clone(),
208            )?;
209            final_config.feature_map = profile.config.feature_map.clone();
210            features.extend(profile.features);
211        }
212
213        if let Some(num_array) = num_array {
214            let dtype = dtype.unwrap();
215            let drift_profile = if dtype == "float64" {
216                let array = convert_array_type::<f64>(num_array, &dtype)?;
217                self.create_numeric_drift_profile(array, num_features, final_config.clone())?
218            } else {
219                let array = convert_array_type::<f32>(num_array, &dtype)?;
220                self.create_numeric_drift_profile(array, num_features, final_config.clone())?
221            };
222            features.extend(drift_profile.features);
223        }
224
225        // if config.features_to_monitor is empty, set it to all features
226        if config.alert_config.features_to_monitor.is_empty() {
227            final_config.alert_config.features_to_monitor = features.keys().cloned().collect();
228        }
229
230        // Validate features_to_monitor
231        if let Some(missing_feature) = final_config
232            .alert_config
233            .features_to_monitor
234            .iter()
235            .find(|&key| !features.contains_key(key))
236        {
237            return Err(DriftError::FeatureToMonitorMissingError(
238                missing_feature.to_string(),
239            ));
240        }
241
242        Ok(SpcDriftProfile::new(features, final_config, None))
243    }
244}