scouter_client/drifter/
spc.rs

1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::Axis;
3use ndarray::{concatenate, Array2};
4use num_traits::{Float, FromPrimitive, Num};
5use numpy::PyReadonlyArray2;
6use scouter_drift::error::DriftError;
7use scouter_drift::{
8    spc::{generate_alerts, SpcDriftMap, SpcMonitor},
9    CategoricalFeatureHelpers,
10};
11use scouter_types::{
12    create_feature_map,
13    spc::{SpcAlertRule, SpcDriftConfig, SpcDriftProfile, SpcFeatureAlerts},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::sync::RwLock;
19#[derive(Default)]
20pub struct SpcDrifter {
21    monitor: SpcMonitor,
22}
23
24impl SpcDrifter {
25    pub fn new() -> Self {
26        let monitor = SpcMonitor::new();
27        SpcDrifter { monitor }
28    }
29
30    pub fn convert_strings_to_numpy_f32(
31        &mut self,
32        features: Vec<String>,
33        array: Vec<Vec<String>>,
34        drift_profile: SpcDriftProfile,
35    ) -> Result<Array2<f32>, DriftError> {
36        let array = self.monitor.convert_strings_to_ndarray_f32(
37            &features,
38            &array,
39            &drift_profile.config.feature_map,
40        )?;
41
42        Ok(array)
43    }
44
45    pub fn convert_strings_to_numpy_f64(
46        &mut self,
47        features: Vec<String>,
48        array: Vec<Vec<String>>,
49        drift_profile: SpcDriftProfile,
50    ) -> Result<Array2<f64>, DriftError> {
51        let array = self.monitor.convert_strings_to_ndarray_f64(
52            &features,
53            &array,
54            &drift_profile.config.feature_map,
55        )?;
56
57        Ok(array)
58    }
59
60    pub fn generate_alerts(
61        &mut self,
62        drift_array: PyReadonlyArray2<f64>,
63        features: Vec<String>,
64        alert_rule: SpcAlertRule,
65    ) -> Result<SpcFeatureAlerts, DriftError> {
66        let drift_array = drift_array.as_array();
67
68        generate_alerts(&drift_array, &features, &alert_rule)
69    }
70
71    pub fn create_string_drift_profile(
72        &mut self,
73        array: Vec<Vec<String>>,
74        features: Vec<String>,
75        drift_config: Arc<RwLock<SpcDriftConfig>>,
76    ) -> Result<SpcDriftProfile, DriftError> {
77        let feature_map = match create_feature_map(&features, &array) {
78            Ok(feature_map) => feature_map,
79            Err(e) => {
80                return Err(e.into());
81            }
82        };
83
84        drift_config
85            .write()
86            .unwrap()
87            .update_feature_map(feature_map.clone());
88
89        let array = self
90            .monitor
91            .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
92
93        self.monitor.create_2d_drift_profile(
94            &features,
95            &array.view(),
96            &drift_config.read().unwrap(),
97        )
98    }
99
100    pub fn create_numeric_drift_profile<F>(
101        &mut self,
102        array: PyReadonlyArray2<F>,
103        features: Vec<String>,
104        drift_config: &SpcDriftConfig,
105    ) -> Result<SpcDriftProfile, DriftError>
106    where
107        F: Float
108            + Sync
109            + FromPrimitive
110            + Send
111            + Num
112            + Debug
113            + num_traits::Zero
114            + ndarray::ScalarOperand
115            + numpy::Element,
116        F: Into<f64>,
117    {
118        let array = array.as_array();
119
120        let profile = self
121            .monitor
122            .create_2d_drift_profile(&features, &array, drift_config)?;
123
124        Ok(profile)
125    }
126
127    pub fn compute_drift(
128        &mut self,
129        data: ConvertedData<'_>,
130        drift_profile: SpcDriftProfile,
131    ) -> Result<SpcDriftMap, DriftError> {
132        let (num_features, num_array, dtype, string_features, string_array) = data;
133        let dtype = dtype.unwrap_or("float32".to_string());
134
135        let mut features = num_features.clone();
136        features.extend(string_features.clone());
137
138        if let Some(string_array) = string_array {
139            if dtype == "float64" {
140                let string_array = self.convert_strings_to_numpy_f64(
141                    string_features,
142                    string_array,
143                    drift_profile.clone(),
144                )?;
145
146                if num_array.is_some() {
147                    let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
148                    let concatenated =
149                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
150                    Ok(self.monitor.compute_drift(
151                        &features,
152                        &concatenated.view(),
153                        &drift_profile,
154                    )?)
155                } else {
156                    Ok(self.monitor.compute_drift(
157                        &features,
158                        &string_array.view(),
159                        &drift_profile,
160                    )?)
161                }
162            } else {
163                let string_array = self.convert_strings_to_numpy_f32(
164                    string_features,
165                    string_array,
166                    drift_profile.clone(),
167                )?;
168
169                if num_array.is_some() {
170                    let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
171                    let concatenated =
172                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
173                    Ok(self.monitor.compute_drift(
174                        &features,
175                        &concatenated.view(),
176                        &drift_profile,
177                    )?)
178                } else {
179                    Ok(self.monitor.compute_drift(
180                        &features,
181                        &string_array.view(),
182                        &drift_profile,
183                    )?)
184                }
185            }
186        } else if dtype == "float64" {
187            let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
188            Ok(self
189                .monitor
190                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
191        } else {
192            let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
193            Ok(self
194                .monitor
195                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
196        }
197    }
198
199    pub fn create_drift_profile(
200        &mut self,
201        data: ConvertedData<'_>,
202        config: Arc<RwLock<SpcDriftConfig>>,
203    ) -> Result<SpcDriftProfile, DriftError> {
204        let (num_features, num_array, dtype, string_features, string_array) = data;
205
206        let mut features = HashMap::new();
207
208        if let Some(string_array) = string_array {
209            let profile =
210                self.create_string_drift_profile(string_array, string_features, config.clone())?;
211            features.extend(profile.features);
212        }
213
214        if let Some(num_array) = num_array {
215            let dtype = dtype.unwrap();
216            let drift_profile = {
217                let read_config = config.read().unwrap();
218                if dtype == "float64" {
219                    let array = convert_array_type::<f64>(num_array, &dtype)?;
220                    self.create_numeric_drift_profile(array, num_features, &read_config)?
221                } else {
222                    let array = convert_array_type::<f32>(num_array, &dtype)?;
223                    self.create_numeric_drift_profile(array, num_features, &read_config)?
224                }
225            };
226            features.extend(drift_profile.features);
227        }
228
229        {
230            let mut write_config = config.write().unwrap();
231
232            if write_config.alert_config.features_to_monitor.is_empty() {
233                write_config.alert_config.features_to_monitor = features.keys().cloned().collect();
234            }
235
236            // Validate features_to_monitor
237            if let Some(missing_feature) = write_config
238                .alert_config
239                .features_to_monitor
240                .iter()
241                .find(|&key| !features.contains_key(key))
242            {
243                return Err(DriftError::FeatureToMonitorMissingError(
244                    missing_feature.to_string(),
245                ));
246            }
247        }
248
249        let config_clone = config.read().unwrap().clone();
250
251        Ok(SpcDriftProfile::new(features, config_clone))
252    }
253}