1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::Axis;
3use ndarray::{concatenate, Array2};
4use num_traits::{Float, FromPrimitive, Num};
5use numpy::PyReadonlyArray2;
6use scouter_drift::error::DriftError;
7use scouter_drift::{
8 spc::{generate_alerts, SpcDriftMap, SpcMonitor},
9 CategoricalFeatureHelpers,
10};
11use scouter_types::{
12 create_feature_map,
13 spc::{SpcAlertRule, SpcDriftConfig, SpcDriftProfile, SpcFeatureAlerts},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18#[derive(Default)]
19pub struct SpcDrifter {
20 monitor: SpcMonitor,
21}
22
23impl SpcDrifter {
24 pub fn new() -> Self {
25 let monitor = SpcMonitor::new();
26 SpcDrifter { monitor }
27 }
28
29 pub fn convert_strings_to_numpy_f32(
30 &mut self,
31 features: Vec<String>,
32 array: Vec<Vec<String>>,
33 drift_profile: SpcDriftProfile,
34 ) -> Result<Array2<f32>, DriftError> {
35 let array = self.monitor.convert_strings_to_ndarray_f32(
36 &features,
37 &array,
38 &drift_profile.config.feature_map,
39 )?;
40
41 Ok(array)
42 }
43
44 pub fn convert_strings_to_numpy_f64(
45 &mut self,
46 features: Vec<String>,
47 array: Vec<Vec<String>>,
48 drift_profile: SpcDriftProfile,
49 ) -> Result<Array2<f64>, DriftError> {
50 let array = self.monitor.convert_strings_to_ndarray_f64(
51 &features,
52 &array,
53 &drift_profile.config.feature_map,
54 )?;
55
56 Ok(array)
57 }
58
59 pub fn generate_alerts(
60 &mut self,
61 drift_array: PyReadonlyArray2<f64>,
62 features: Vec<String>,
63 alert_rule: SpcAlertRule,
64 ) -> Result<SpcFeatureAlerts, DriftError> {
65 let drift_array = drift_array.as_array();
66
67 generate_alerts(&drift_array, &features, &alert_rule)
68 }
69
70 pub fn create_string_drift_profile(
71 &mut self,
72 array: Vec<Vec<String>>,
73 features: Vec<String>,
74 mut drift_config: SpcDriftConfig,
75 ) -> Result<SpcDriftProfile, DriftError> {
76 let feature_map = match create_feature_map(&features, &array) {
77 Ok(feature_map) => feature_map,
78 Err(e) => {
79 return Err(e.into());
80 }
81 };
82
83 drift_config.update_feature_map(feature_map.clone());
84
85 let array = self
86 .monitor
87 .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
88
89 self.monitor
90 .create_2d_drift_profile(&features, &array.view(), &drift_config)
91 }
92
93 pub fn create_numeric_drift_profile<F>(
94 &mut self,
95 array: PyReadonlyArray2<F>,
96 features: Vec<String>,
97 drift_config: SpcDriftConfig,
98 ) -> Result<SpcDriftProfile, DriftError>
99 where
100 F: Float
101 + Sync
102 + FromPrimitive
103 + Send
104 + Num
105 + Debug
106 + num_traits::Zero
107 + ndarray::ScalarOperand
108 + numpy::Element,
109 F: Into<f64>,
110 {
111 let array = array.as_array();
112
113 let profile = self
114 .monitor
115 .create_2d_drift_profile(&features, &array, &drift_config)?;
116
117 Ok(profile)
118 }
119
120 pub fn compute_drift(
121 &mut self,
122 data: ConvertedData<'_>,
123 drift_profile: SpcDriftProfile,
124 ) -> Result<SpcDriftMap, DriftError> {
125 let (num_features, num_array, dtype, string_features, string_array) = data;
126 let dtype = dtype.unwrap_or("float32".to_string());
127
128 let mut features = num_features.clone();
129 features.extend(string_features.clone());
130
131 if let Some(string_array) = string_array {
132 if dtype == "float64" {
133 let string_array = self.convert_strings_to_numpy_f64(
134 string_features,
135 string_array,
136 drift_profile.clone(),
137 )?;
138
139 if num_array.is_some() {
140 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
141 let concatenated =
142 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
143 Ok(self.monitor.compute_drift(
144 &features,
145 &concatenated.view(),
146 &drift_profile,
147 )?)
148 } else {
149 Ok(self.monitor.compute_drift(
150 &features,
151 &string_array.view(),
152 &drift_profile,
153 )?)
154 }
155 } else {
156 let string_array = self.convert_strings_to_numpy_f32(
157 string_features,
158 string_array,
159 drift_profile.clone(),
160 )?;
161
162 if num_array.is_some() {
163 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
164 let concatenated =
165 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
166 Ok(self.monitor.compute_drift(
167 &features,
168 &concatenated.view(),
169 &drift_profile,
170 )?)
171 } else {
172 Ok(self.monitor.compute_drift(
173 &features,
174 &string_array.view(),
175 &drift_profile,
176 )?)
177 }
178 }
179 } else if dtype == "float64" {
180 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
181 Ok(self
182 .monitor
183 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
184 } else {
185 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
186 Ok(self
187 .monitor
188 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
189 }
190 }
191
192 pub fn create_drift_profile(
193 &mut self,
194 data: ConvertedData<'_>,
195 config: SpcDriftConfig,
196 ) -> Result<SpcDriftProfile, DriftError> {
197 let (num_features, num_array, dtype, string_features, string_array) = data;
198
199 let mut features = HashMap::new();
200 let cfg = config.clone();
201 let mut final_config = cfg.clone();
202
203 if let Some(string_array) = string_array {
204 let profile = self.create_string_drift_profile(
205 string_array,
206 string_features,
207 final_config.clone(),
208 )?;
209 final_config.feature_map = profile.config.feature_map.clone();
210 features.extend(profile.features);
211 }
212
213 if let Some(num_array) = num_array {
214 let dtype = dtype.unwrap();
215 let drift_profile = if dtype == "float64" {
216 let array = convert_array_type::<f64>(num_array, &dtype)?;
217 self.create_numeric_drift_profile(array, num_features, final_config.clone())?
218 } else {
219 let array = convert_array_type::<f32>(num_array, &dtype)?;
220 self.create_numeric_drift_profile(array, num_features, final_config.clone())?
221 };
222 features.extend(drift_profile.features);
223 }
224
225 if config.alert_config.features_to_monitor.is_empty() {
227 final_config.alert_config.features_to_monitor = features.keys().cloned().collect();
228 }
229
230 if let Some(missing_feature) = final_config
232 .alert_config
233 .features_to_monitor
234 .iter()
235 .find(|&key| !features.contains_key(key))
236 {
237 return Err(DriftError::FeatureToMonitorMissingError(
238 missing_feature.to_string(),
239 ));
240 }
241
242 Ok(SpcDriftProfile::new(features, final_config, None))
243 }
244}