1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::{concatenate, Array2, Axis};
3use num_traits::{Float, FromPrimitive};
4use numpy::PyReadonlyArray2;
5use scouter_drift::error::DriftError;
6use scouter_drift::{psi::PsiMonitor, CategoricalFeatureHelpers};
7use scouter_types::psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use tracing::instrument;
11#[derive(Default)]
12pub struct PsiDrifter {
13 monitor: PsiMonitor,
14}
15
16impl PsiDrifter {
17 pub fn new() -> Self {
18 let monitor = PsiMonitor::new();
19 PsiDrifter { monitor }
20 }
21
22 pub fn convert_strings_to_numpy_f32(
23 &mut self,
24 features: Vec<String>,
25 array: Vec<Vec<String>>,
26 drift_profile: PsiDriftProfile,
27 ) -> Result<Array2<f32>, DriftError> {
28 let array = self.monitor.convert_strings_to_ndarray_f32(
29 &features,
30 &array,
31 &drift_profile.config.feature_map,
32 )?;
33
34 Ok(array)
35 }
36
37 pub fn convert_strings_to_numpy_f64(
38 &mut self,
39 features: Vec<String>,
40 array: Vec<Vec<String>>,
41 drift_profile: PsiDriftProfile,
42 ) -> Result<Array2<f64>, DriftError> {
43 let array = self.monitor.convert_strings_to_ndarray_f64(
44 &features,
45 &array,
46 &drift_profile.config.feature_map,
47 )?;
48
49 Ok(array)
50 }
51
52 #[instrument(skip_all)]
53 pub fn create_string_drift_profile(
54 &mut self,
55 array: Vec<Vec<String>>,
56 features: Vec<String>,
57 drift_config: Arc<RwLock<PsiDriftConfig>>,
58 ) -> Result<PsiDriftProfile, DriftError> {
59 let feature_map = self.monitor.create_feature_map(&features, &array)?;
60
61 drift_config
62 .write()
63 .unwrap()
64 .update_feature_map(feature_map.clone());
65
66 let array = self
67 .monitor
68 .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
69
70 let profile = self.monitor.create_2d_drift_profile(
71 &features,
72 &array.view(),
73 &drift_config.read().unwrap(),
74 )?;
75
76 Ok(profile)
77 }
78
79 pub fn create_numeric_drift_profile<F>(
80 &mut self,
81 array: PyReadonlyArray2<F>,
82 features: Vec<String>,
83 drift_config: &PsiDriftConfig,
84 ) -> Result<PsiDriftProfile, DriftError>
85 where
86 F: Float + Sync + FromPrimitive + Default,
87 F: Into<f64>,
88 F: numpy::Element,
89 {
90 let array = array.as_array();
91
92 let profile = self
93 .monitor
94 .create_2d_drift_profile(&features, &array, drift_config)?;
95
96 Ok(profile)
97 }
98
99 pub fn create_drift_profile(
100 &mut self,
101 data: ConvertedData<'_>,
102 config: Arc<RwLock<PsiDriftConfig>>,
103 ) -> Result<PsiDriftProfile, DriftError> {
104 let (num_features, num_array, dtype, string_features, string_array) = data;
105
106 {
108 let read_config = config.read().unwrap();
109 if let Some(categorical_features) = read_config.categorical_features.as_ref() {
110 if let Some(missing_feature) = categorical_features
111 .iter()
112 .find(|&key| !num_features.contains(key) && !string_features.contains(key))
113 {
114 return Err(DriftError::CategoricalFeatureMissingError(
115 missing_feature.to_string(),
116 ));
117 }
118 }
119 }
120
121 let mut features = HashMap::new();
122
123 if let Some(string_array) = string_array {
124 let profile =
125 self.create_string_drift_profile(string_array, string_features, config.clone())?;
126 features.extend(profile.features);
127 }
128
129 if let Some(num_array) = num_array {
130 let dtype = dtype.unwrap();
131 let drift_profile = {
132 let read_config = config.read().unwrap();
133 if dtype == "float64" {
134 let array = convert_array_type::<f64>(num_array, &dtype)?;
135 self.create_numeric_drift_profile(array, num_features, &read_config)?
136 } else {
137 let array = convert_array_type::<f32>(num_array, &dtype)?;
138 self.create_numeric_drift_profile(array, num_features, &read_config)?
139 }
140 };
141 features.extend(drift_profile.features);
142 }
143
144 {
146 let mut write_config = config.write().unwrap();
147 if write_config.alert_config.features_to_monitor.is_empty() {
148 write_config.alert_config.features_to_monitor = features.keys().cloned().collect();
149 }
150
151 if let Some(missing_feature) = write_config
153 .alert_config
154 .features_to_monitor
155 .iter()
156 .find(|&key| !features.contains_key(key))
157 {
158 return Err(DriftError::FeatureToMonitorMissingError(
159 missing_feature.to_string(),
160 ));
161 }
162 }
163
164 let config_clone = config.read().unwrap().clone();
165
166 Ok(PsiDriftProfile::new(features, config_clone))
167 }
168
169 pub fn compute_drift(
170 &mut self,
171 data: ConvertedData<'_>,
172 drift_profile: PsiDriftProfile,
173 ) -> Result<PsiDriftMap, DriftError> {
174 let (num_features, num_array, dtype, string_features, string_array) = data;
175 let dtype = dtype.unwrap_or("float32".to_string());
176
177 let mut features = num_features.clone();
178 features.extend(string_features.clone());
179
180 if let Some(string_array) = string_array {
181 if dtype == "float64" {
182 let string_array = self.convert_strings_to_numpy_f64(
183 string_features,
184 string_array,
185 drift_profile.clone(),
186 )?;
187
188 if num_array.is_some() {
189 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
190 let concatenated =
191 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
192 Ok(self.monitor.compute_drift(
193 &features,
194 &concatenated.view(),
195 &drift_profile,
196 )?)
197 } else {
198 Ok(self.monitor.compute_drift(
199 &features,
200 &string_array.view(),
201 &drift_profile,
202 )?)
203 }
204 } else {
205 let string_array = self.convert_strings_to_numpy_f32(
206 string_features,
207 string_array,
208 drift_profile.clone(),
209 )?;
210
211 if num_array.is_some() {
212 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
213 let concatenated =
214 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
215 Ok(self.monitor.compute_drift(
216 &features,
217 &concatenated.view(),
218 &drift_profile,
219 )?)
220 } else {
221 Ok(self.monitor.compute_drift(
222 &features,
223 &string_array.view(),
224 &drift_profile,
225 )?)
226 }
227 }
228 } else if dtype == "float64" {
229 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
230 Ok(self
231 .monitor
232 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
233 } else {
234 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
235 Ok(self
236 .monitor
237 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
238 }
239 }
240}