scouter_client/drifter/
psi.rs

1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::{concatenate, Array2, Axis};
3use num_traits::{Float, FromPrimitive};
4use numpy::PyReadonlyArray2;
5use scouter_drift::error::DriftError;
6use scouter_drift::{psi::PsiMonitor, CategoricalFeatureHelpers};
7use scouter_types::psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10use tracing::instrument;
/// Client-side wrapper around a [`PsiMonitor`].
///
/// The methods implemented on this type forward to the wrapped monitor to
/// convert string data to arrays, create PSI drift profiles, and compute
/// drift against an existing profile.
11#[derive(Default)]
12pub struct PsiDrifter {
    // Monitor that the methods of this type delegate their work to.
13    monitor: PsiMonitor,
14}
15
16impl PsiDrifter {
17    pub fn new() -> Self {
18        let monitor = PsiMonitor::new();
19        PsiDrifter { monitor }
20    }
21
22    pub fn convert_strings_to_numpy_f32(
23        &mut self,
24        features: Vec<String>,
25        array: Vec<Vec<String>>,
26        drift_profile: PsiDriftProfile,
27    ) -> Result<Array2<f32>, DriftError> {
28        let array = self.monitor.convert_strings_to_ndarray_f32(
29            &features,
30            &array,
31            &drift_profile.config.feature_map,
32        )?;
33
34        Ok(array)
35    }
36
37    pub fn convert_strings_to_numpy_f64(
38        &mut self,
39        features: Vec<String>,
40        array: Vec<Vec<String>>,
41        drift_profile: PsiDriftProfile,
42    ) -> Result<Array2<f64>, DriftError> {
43        let array = self.monitor.convert_strings_to_ndarray_f64(
44            &features,
45            &array,
46            &drift_profile.config.feature_map,
47        )?;
48
49        Ok(array)
50    }
51
52    #[instrument(skip_all)]
53    pub fn create_string_drift_profile(
54        &mut self,
55        array: Vec<Vec<String>>,
56        features: Vec<String>,
57        drift_config: Arc<RwLock<PsiDriftConfig>>,
58    ) -> Result<PsiDriftProfile, DriftError> {
59        let feature_map = self.monitor.create_feature_map(&features, &array)?;
60
61        drift_config
62            .write()
63            .unwrap()
64            .update_feature_map(feature_map.clone());
65
66        let array = self
67            .monitor
68            .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
69
70        let profile = self.monitor.create_2d_drift_profile(
71            &features,
72            &array.view(),
73            &drift_config.read().unwrap(),
74        )?;
75
76        Ok(profile)
77    }
78
79    pub fn create_numeric_drift_profile<F>(
80        &mut self,
81        array: PyReadonlyArray2<F>,
82        features: Vec<String>,
83        drift_config: &PsiDriftConfig,
84    ) -> Result<PsiDriftProfile, DriftError>
85    where
86        F: Float + Sync + FromPrimitive + Default,
87        F: Into<f64>,
88        F: numpy::Element,
89    {
90        let array = array.as_array();
91
92        let profile = self
93            .monitor
94            .create_2d_drift_profile(&features, &array, drift_config)?;
95
96        Ok(profile)
97    }
98
99    pub fn create_drift_profile(
100        &mut self,
101        data: ConvertedData<'_>,
102        config: Arc<RwLock<PsiDriftConfig>>,
103    ) -> Result<PsiDriftProfile, DriftError> {
104        let (num_features, num_array, dtype, string_features, string_array) = data;
105
106        // Validate categorical_features
107        {
108            let read_config = config.read().unwrap();
109            if let Some(categorical_features) = read_config.categorical_features.as_ref() {
110                if let Some(missing_feature) = categorical_features
111                    .iter()
112                    .find(|&key| !num_features.contains(key) && !string_features.contains(key))
113                {
114                    return Err(DriftError::CategoricalFeatureMissingError(
115                        missing_feature.to_string(),
116                    ));
117                }
118            }
119        }
120
121        let mut features = HashMap::new();
122
123        if let Some(string_array) = string_array {
124            let profile =
125                self.create_string_drift_profile(string_array, string_features, config.clone())?;
126            features.extend(profile.features);
127        }
128
129        if let Some(num_array) = num_array {
130            let dtype = dtype.unwrap();
131            let drift_profile = {
132                let read_config = config.read().unwrap();
133                if dtype == "float64" {
134                    let array = convert_array_type::<f64>(num_array, &dtype)?;
135                    self.create_numeric_drift_profile(array, num_features, &read_config)?
136                } else {
137                    let array = convert_array_type::<f32>(num_array, &dtype)?;
138                    self.create_numeric_drift_profile(array, num_features, &read_config)?
139                }
140            };
141            features.extend(drift_profile.features);
142        }
143
144        // if config.features_to_monitor is empty, set it to all features
145        {
146            let mut write_config = config.write().unwrap();
147            if write_config.alert_config.features_to_monitor.is_empty() {
148                write_config.alert_config.features_to_monitor = features.keys().cloned().collect();
149            }
150
151            // Validate features_to_monitor
152            if let Some(missing_feature) = write_config
153                .alert_config
154                .features_to_monitor
155                .iter()
156                .find(|&key| !features.contains_key(key))
157            {
158                return Err(DriftError::FeatureToMonitorMissingError(
159                    missing_feature.to_string(),
160                ));
161            }
162        }
163
164        let config_clone = config.read().unwrap().clone();
165
166        Ok(PsiDriftProfile::new(features, config_clone))
167    }
168
169    pub fn compute_drift(
170        &mut self,
171        data: ConvertedData<'_>,
172        drift_profile: PsiDriftProfile,
173    ) -> Result<PsiDriftMap, DriftError> {
174        let (num_features, num_array, dtype, string_features, string_array) = data;
175        let dtype = dtype.unwrap_or("float32".to_string());
176
177        let mut features = num_features.clone();
178        features.extend(string_features.clone());
179
180        if let Some(string_array) = string_array {
181            if dtype == "float64" {
182                let string_array = self.convert_strings_to_numpy_f64(
183                    string_features,
184                    string_array,
185                    drift_profile.clone(),
186                )?;
187
188                if num_array.is_some() {
189                    let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
190                    let concatenated =
191                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
192                    Ok(self.monitor.compute_drift(
193                        &features,
194                        &concatenated.view(),
195                        &drift_profile,
196                    )?)
197                } else {
198                    Ok(self.monitor.compute_drift(
199                        &features,
200                        &string_array.view(),
201                        &drift_profile,
202                    )?)
203                }
204            } else {
205                let string_array = self.convert_strings_to_numpy_f32(
206                    string_features,
207                    string_array,
208                    drift_profile.clone(),
209                )?;
210
211                if num_array.is_some() {
212                    let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
213                    let concatenated =
214                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
215                    Ok(self.monitor.compute_drift(
216                        &features,
217                        &concatenated.view(),
218                        &drift_profile,
219                    )?)
220                } else {
221                    Ok(self.monitor.compute_drift(
222                        &features,
223                        &string_array.view(),
224                        &drift_profile,
225                    )?)
226                }
227            }
228        } else if dtype == "float64" {
229            let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
230            Ok(self
231                .monitor
232                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
233        } else {
234            let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
235            Ok(self
236                .monitor
237                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
238        }
239    }
240}