1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::{concatenate, Array2, Axis};
3use num_traits::{Float, FromPrimitive};
4use numpy::PyReadonlyArray2;
5use scouter_drift::error::DriftError;
6use scouter_drift::{psi::PsiMonitor, CategoricalFeatureHelpers};
7use scouter_types::psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile};
8use std::collections::HashMap;
9use tracing::instrument;
10
11#[derive(Default)]
12pub struct PsiDrifter {
13 monitor: PsiMonitor,
14}
15
16impl PsiDrifter {
17 pub fn new() -> Self {
18 let monitor = PsiMonitor::new();
19 PsiDrifter { monitor }
20 }
21
22 pub fn convert_strings_to_numpy_f32(
23 &mut self,
24 features: Vec<String>,
25 array: Vec<Vec<String>>,
26 drift_profile: PsiDriftProfile,
27 ) -> Result<Array2<f32>, DriftError> {
28 let array = self.monitor.convert_strings_to_ndarray_f32(
29 &features,
30 &array,
31 &drift_profile.config.feature_map,
32 )?;
33
34 Ok(array)
35 }
36
37 pub fn convert_strings_to_numpy_f64(
38 &mut self,
39 features: Vec<String>,
40 array: Vec<Vec<String>>,
41 drift_profile: PsiDriftProfile,
42 ) -> Result<Array2<f64>, DriftError> {
43 let array = self.monitor.convert_strings_to_ndarray_f64(
44 &features,
45 &array,
46 &drift_profile.config.feature_map,
47 )?;
48
49 Ok(array)
50 }
51
52 #[instrument(skip_all)]
53 pub fn create_string_drift_profile(
54 &mut self,
55 array: Vec<Vec<String>>,
56 features: Vec<String>,
57 mut drift_config: PsiDriftConfig,
58 ) -> Result<PsiDriftProfile, DriftError> {
59 let feature_map = self.monitor.create_feature_map(&features, &array)?;
60
61 drift_config.update_feature_map(feature_map.clone());
62
63 let array = self
64 .monitor
65 .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
66
67 let profile =
68 self.monitor
69 .create_2d_drift_profile(&features, &array.view(), &drift_config)?;
70
71 Ok(profile)
72 }
73
74 pub fn create_numeric_drift_profile<F>(
75 &mut self,
76 array: PyReadonlyArray2<F>,
77 features: Vec<String>,
78 drift_config: PsiDriftConfig,
79 ) -> Result<PsiDriftProfile, DriftError>
80 where
81 F: Float + Sync + FromPrimitive + Default + PartialOrd,
82 F: Into<f64>,
83 F: numpy::Element,
84 {
85 let array = array.as_array();
86
87 let profile = self
88 .monitor
89 .create_2d_drift_profile(&features, &array, &drift_config)?;
90
91 Ok(profile)
92 }
93
94 pub fn create_drift_profile(
95 &mut self,
96 data: ConvertedData<'_>,
97 config: PsiDriftConfig,
98 ) -> Result<PsiDriftProfile, DriftError> {
99 let (num_features, num_array, dtype, string_features, string_array) = data;
100
101 let mut final_config = config.clone();
102
103 if let Some(categorical_features) = final_config.categorical_features.as_ref() {
105 if let Some(missing_feature) = categorical_features
107 .iter()
108 .find(|&key| !num_features.contains(key) && !string_features.contains(key))
109 {
110 return Err(DriftError::CategoricalFeatureMissingError(
111 missing_feature.to_string(),
112 ));
113 }
114 }
115
116 let mut features = HashMap::new();
117
118 if let Some(string_array) = string_array {
119 let profile = self.create_string_drift_profile(
120 string_array,
121 string_features,
122 final_config.clone(),
123 )?;
124 final_config.feature_map = profile.config.feature_map.clone();
125 features.extend(profile.features);
126 }
127
128 if let Some(num_array) = num_array {
129 let dtype = dtype.unwrap();
130 let drift_profile = if dtype == "float64" {
131 let array = convert_array_type::<f64>(num_array, &dtype)?;
132 self.create_numeric_drift_profile(array, num_features, final_config.clone())?
133 } else {
134 let array = convert_array_type::<f32>(num_array, &dtype)?;
135 self.create_numeric_drift_profile(array, num_features, final_config.clone())?
136 };
137 features.extend(drift_profile.features);
138 }
139
140 if final_config.alert_config.features_to_monitor.is_empty() {
142 final_config.alert_config.features_to_monitor = features.keys().cloned().collect();
143 }
144
145 if let Some(missing_feature) = final_config
147 .alert_config
148 .features_to_monitor
149 .iter()
150 .find(|&key| !features.contains_key(key))
151 {
152 return Err(DriftError::FeatureToMonitorMissingError(
153 missing_feature.to_string(),
154 ));
155 }
156
157 Ok(PsiDriftProfile::new(features, final_config, None))
158 }
159
160 pub fn compute_drift(
161 &mut self,
162 data: ConvertedData<'_>,
163 drift_profile: PsiDriftProfile,
164 ) -> Result<PsiDriftMap, DriftError> {
165 let (num_features, num_array, dtype, string_features, string_array) = data;
166 let dtype = dtype.unwrap_or("float32".to_string());
167
168 let mut features = num_features.clone();
169 features.extend(string_features.clone());
170
171 if let Some(string_array) = string_array {
172 if dtype == "float64" {
173 let string_array = self.convert_strings_to_numpy_f64(
174 string_features,
175 string_array,
176 drift_profile.clone(),
177 )?;
178
179 if num_array.is_some() {
180 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
181 let concatenated =
182 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
183 Ok(self.monitor.compute_drift(
184 &features,
185 &concatenated.view(),
186 &drift_profile,
187 )?)
188 } else {
189 Ok(self.monitor.compute_drift(
190 &features,
191 &string_array.view(),
192 &drift_profile,
193 )?)
194 }
195 } else {
196 let string_array = self.convert_strings_to_numpy_f32(
197 string_features,
198 string_array,
199 drift_profile.clone(),
200 )?;
201
202 if num_array.is_some() {
203 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
204 let concatenated =
205 concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
206 Ok(self.monitor.compute_drift(
207 &features,
208 &concatenated.view(),
209 &drift_profile,
210 )?)
211 } else {
212 Ok(self.monitor.compute_drift(
213 &features,
214 &string_array.view(),
215 &drift_profile,
216 )?)
217 }
218 }
219 } else if dtype == "float64" {
220 let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
221 Ok(self
222 .monitor
223 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
224 } else {
225 let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
226 Ok(self
227 .monitor
228 .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
229 }
230 }
231}