1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::Axis;
3use ndarray::{concatenate, Array2};
4use num_traits::{Float, FromPrimitive, Num};
5use numpy::PyReadonlyArray2;
6use scouter_drift::error::DriftError;
7use scouter_drift::{
8 spc::{generate_alerts, SpcDriftMap, SpcMonitor},
9 CategoricalFeatureHelpers,
10};
11use scouter_types::{
12 create_feature_map,
13 spc::{SpcAlertRule, SpcDriftConfig, SpcDriftProfile, SpcFeatureAlerts},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::sync::Arc;
18use std::sync::RwLock;
19#[derive(Default)]
20pub struct SpcDrifter {
21 monitor: SpcMonitor,
22}
23
24impl SpcDrifter {
25 pub fn new() -> Self {
26 let monitor = SpcMonitor::new();
27 SpcDrifter { monitor }
28 }
29
30 pub fn convert_strings_to_numpy_f32(
31 &mut self,
32 features: Vec<String>,
33 array: Vec<Vec<String>>,
34 drift_profile: SpcDriftProfile,
35 ) -> Result<Array2<f32>, DriftError> {
36 let array = self.monitor.convert_strings_to_ndarray_f32(
37 &features,
38 &array,
39 &drift_profile.config.feature_map,
40 )?;
41
42 Ok(array)
43 }
44
45 pub fn convert_strings_to_numpy_f64(
46 &mut self,
47 features: Vec<String>,
48 array: Vec<Vec<String>>,
49 drift_profile: SpcDriftProfile,
50 ) -> Result<Array2<f64>, DriftError> {
51 let array = self.monitor.convert_strings_to_ndarray_f64(
52 &features,
53 &array,
54 &drift_profile.config.feature_map,
55 )?;
56
57 Ok(array)
58 }
59
60 pub fn generate_alerts(
61 &mut self,
62 drift_array: PyReadonlyArray2<f64>,
63 features: Vec<String>,
64 alert_rule: SpcAlertRule,
65 ) -> Result<SpcFeatureAlerts, DriftError> {
66 let drift_array = drift_array.as_array();
67
68 generate_alerts(&drift_array, &features, &alert_rule)
69 }
70
71 pub fn create_string_drift_profile(
72 &mut self,
73 array: Vec<Vec<String>>,
74 features: Vec<String>,
75 drift_config: Arc<RwLock<SpcDriftConfig>>,
76 ) -> Result<SpcDriftProfile, DriftError> {
77 let feature_map = match create_feature_map(&features, &array) {
78 Ok(feature_map) => feature_map,
79 Err(e) => {
80 return Err(e.into());
81 }
82 };
83
84 drift_config
85 .write()
86 .unwrap()
87 .update_feature_map(feature_map.clone());
88
89 let array = self
90 .monitor
91 .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
92
93 self.monitor.create_2d_drift_profile(
94 &features,
95 &array.view(),
96 &drift_config.read().unwrap(),
97 )
98 }
99
100 pub fn create_numeric_drift_profile<F>(
101 &mut self,
102 array: PyReadonlyArray2<F>,
103 features: Vec<String>,
104 drift_config: &SpcDriftConfig,
105 ) -> Result<SpcDriftProfile, DriftError>
106 where
107 F: Float
108 + Sync
109 + FromPrimitive
110 + Send
111 + Num
112 + Debug
113 + num_traits::Zero
114 + ndarray::ScalarOperand
115 + numpy::Element,
116 F: Into<f64>,
117 {
118 let array = array.as_array();
119
120 let profile = self
121 .monitor
122 .create_2d_drift_profile(&features, &array, drift_config)?;
123
124 Ok(profile)
125 }
126
127 pub fn compute_drift(
128 &mut self,
129 data: ConvertedData<'_>,
130 drift_profile: SpcDriftProfile,
131 ) -> Result<SpcDriftMap, DriftError> {
132 let (num_features, num_array, dtype, string_features, string_array) = data;
133 let dtype = dtype.unwrap_or("float32".to_string());
134
135 let mut features = num_features.clone();
136 features.extend(string_features.clone());
137
138 if let Some(string_array) = string_array {
139 if dtype == "float64" {
140 let string_array = self.convert_strings_to_numpy_f64(
141 string_features,
142 string_array,
143 drift_profile.clone(),
144 )?;
145
146 if num_array.is_some() {
147 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
148 let concatenated =
149 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
150 Ok(self.monitor.compute_drift(
151 &features,
152 &concatenated.view(),
153 &drift_profile,
154 )?)
155 } else {
156 Ok(self.monitor.compute_drift(
157 &features,
158 &string_array.view(),
159 &drift_profile,
160 )?)
161 }
162 } else {
163 let string_array = self.convert_strings_to_numpy_f32(
164 string_features,
165 string_array,
166 drift_profile.clone(),
167 )?;
168
169 if num_array.is_some() {
170 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
171 let concatenated =
172 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
173 Ok(self.monitor.compute_drift(
174 &features,
175 &concatenated.view(),
176 &drift_profile,
177 )?)
178 } else {
179 Ok(self.monitor.compute_drift(
180 &features,
181 &string_array.view(),
182 &drift_profile,
183 )?)
184 }
185 }
186 } else if dtype == "float64" {
187 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
188 Ok(self
189 .monitor
190 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
191 } else {
192 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
193 Ok(self
194 .monitor
195 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
196 }
197 }
198
199 pub fn create_drift_profile(
200 &mut self,
201 data: ConvertedData<'_>,
202 config: Arc<RwLock<SpcDriftConfig>>,
203 ) -> Result<SpcDriftProfile, DriftError> {
204 let (num_features, num_array, dtype, string_features, string_array) = data;
205
206 let mut features = HashMap::new();
207
208 if let Some(string_array) = string_array {
209 let profile =
210 self.create_string_drift_profile(string_array, string_features, config.clone())?;
211 features.extend(profile.features);
212 }
213
214 if let Some(num_array) = num_array {
215 let dtype = dtype.unwrap();
216 let drift_profile = {
217 let read_config = config.read().unwrap();
218 if dtype == "float64" {
219 let array = convert_array_type::<f64>(num_array, &dtype)?;
220 self.create_numeric_drift_profile(array, num_features, &read_config)?
221 } else {
222 let array = convert_array_type::<f32>(num_array, &dtype)?;
223 self.create_numeric_drift_profile(array, num_features, &read_config)?
224 }
225 };
226 features.extend(drift_profile.features);
227 }
228
229 {
230 let mut write_config = config.write().unwrap();
231
232 if write_config.alert_config.features_to_monitor.is_empty() {
233 write_config.alert_config.features_to_monitor = features.keys().cloned().collect();
234 }
235
236 if let Some(missing_feature) = write_config
238 .alert_config
239 .features_to_monitor
240 .iter()
241 .find(|&key| !features.contains_key(key))
242 {
243 return Err(DriftError::FeatureToMonitorMissingError(
244 missing_feature.to_string(),
245 ));
246 }
247 }
248
249 let config_clone = config.read().unwrap().clone();
250
251 Ok(SpcDriftProfile::new(features, config_clone))
252 }
253}