scouter_client/drifter/psi.rs

1use crate::data_utils::{convert_array_type, ConvertedData};
2use ndarray::{concatenate, Array2, Axis};
3use num_traits::{Float, FromPrimitive};
4use numpy::PyReadonlyArray2;
5use scouter_drift::error::DriftError;
6use scouter_drift::{psi::PsiMonitor, CategoricalFeatureHelpers};
7use scouter_types::psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile};
8use std::collections::HashMap;
9use tracing::instrument;
10
11#[derive(Default)]
12pub struct PsiDrifter {
13    monitor: PsiMonitor,
14}
15
16impl PsiDrifter {
17    pub fn new() -> Self {
18        let monitor = PsiMonitor::new();
19        PsiDrifter { monitor }
20    }
21
22    pub fn convert_strings_to_numpy_f32(
23        &mut self,
24        features: Vec<String>,
25        array: Vec<Vec<String>>,
26        drift_profile: PsiDriftProfile,
27    ) -> Result<Array2<f32>, DriftError> {
28        let array = self.monitor.convert_strings_to_ndarray_f32(
29            &features,
30            &array,
31            &drift_profile.config.feature_map,
32        )?;
33
34        Ok(array)
35    }
36
37    pub fn convert_strings_to_numpy_f64(
38        &mut self,
39        features: Vec<String>,
40        array: Vec<Vec<String>>,
41        drift_profile: PsiDriftProfile,
42    ) -> Result<Array2<f64>, DriftError> {
43        let array = self.monitor.convert_strings_to_ndarray_f64(
44            &features,
45            &array,
46            &drift_profile.config.feature_map,
47        )?;
48
49        Ok(array)
50    }
51
52    #[instrument(skip_all)]
53    pub fn create_string_drift_profile(
54        &mut self,
55        array: Vec<Vec<String>>,
56        features: Vec<String>,
57        mut drift_config: PsiDriftConfig,
58    ) -> Result<PsiDriftProfile, DriftError> {
59        let feature_map = self.monitor.create_feature_map(&features, &array)?;
60
61        drift_config.update_feature_map(feature_map.clone());
62
63        let array = self
64            .monitor
65            .convert_strings_to_ndarray_f32(&features, &array, &feature_map)?;
66
67        let profile =
68            self.monitor
69                .create_2d_drift_profile(&features, &array.view(), &drift_config)?;
70
71        Ok(profile)
72    }
73
74    pub fn create_numeric_drift_profile<F>(
75        &mut self,
76        array: PyReadonlyArray2<F>,
77        features: Vec<String>,
78        drift_config: PsiDriftConfig,
79    ) -> Result<PsiDriftProfile, DriftError>
80    where
81        F: Float + Sync + FromPrimitive + Default + PartialOrd,
82        F: Into<f64>,
83        F: numpy::Element,
84    {
85        let array = array.as_array();
86
87        let profile = self
88            .monitor
89            .create_2d_drift_profile(&features, &array, &drift_config)?;
90
91        Ok(profile)
92    }
93
94    pub fn create_drift_profile(
95        &mut self,
96        data: ConvertedData<'_>,
97        config: PsiDriftConfig,
98    ) -> Result<PsiDriftProfile, DriftError> {
99        let (num_features, num_array, dtype, string_features, string_array) = data;
100
101        let mut final_config = config.clone();
102
103        // Validate categorical_features
104        if let Some(categorical_features) = final_config.categorical_features.as_ref() {
105            // fail if the specified categorical features are not in the num_features or string_features
106            if let Some(missing_feature) = categorical_features
107                .iter()
108                .find(|&key| !num_features.contains(key) && !string_features.contains(key))
109            {
110                return Err(DriftError::CategoricalFeatureMissingError(
111                    missing_feature.to_string(),
112                ));
113            }
114        }
115
116        let mut features = HashMap::new();
117
118        if let Some(string_array) = string_array {
119            let profile = self.create_string_drift_profile(
120                string_array,
121                string_features,
122                final_config.clone(),
123            )?;
124            final_config.feature_map = profile.config.feature_map.clone();
125            features.extend(profile.features);
126        }
127
128        if let Some(num_array) = num_array {
129            let dtype = dtype.unwrap();
130            let drift_profile = if dtype == "float64" {
131                let array = convert_array_type::<f64>(num_array, &dtype)?;
132                self.create_numeric_drift_profile(array, num_features, final_config.clone())?
133            } else {
134                let array = convert_array_type::<f32>(num_array, &dtype)?;
135                self.create_numeric_drift_profile(array, num_features, final_config.clone())?
136            };
137            features.extend(drift_profile.features);
138        }
139
140        // if config.features_to_monitor is empty, set it to all features
141        if final_config.alert_config.features_to_monitor.is_empty() {
142            final_config.alert_config.features_to_monitor = features.keys().cloned().collect();
143        }
144
145        // Validate features_to_monitor
146        if let Some(missing_feature) = final_config
147            .alert_config
148            .features_to_monitor
149            .iter()
150            .find(|&key| !features.contains_key(key))
151        {
152            return Err(DriftError::FeatureToMonitorMissingError(
153                missing_feature.to_string(),
154            ));
155        }
156
157        Ok(PsiDriftProfile::new(features, final_config, None))
158    }
159
160    pub fn compute_drift(
161        &mut self,
162        data: ConvertedData<'_>,
163        drift_profile: PsiDriftProfile,
164    ) -> Result<PsiDriftMap, DriftError> {
165        let (num_features, num_array, dtype, string_features, string_array) = data;
166        let dtype = dtype.unwrap_or("float32".to_string());
167
168        let mut features = num_features.clone();
169        features.extend(string_features.clone());
170
171        if let Some(string_array) = string_array {
172            if dtype == "float64" {
173                let string_array = self.convert_strings_to_numpy_f64(
174                    string_features,
175                    string_array,
176                    drift_profile.clone(),
177                )?;
178
179                if num_array.is_some() {
180                    let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
181                    let concatenated =
182                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
183                    Ok(self.monitor.compute_drift(
184                        &features,
185                        &concatenated.view(),
186                        &drift_profile,
187                    )?)
188                } else {
189                    Ok(self.monitor.compute_drift(
190                        &features,
191                        &string_array.view(),
192                        &drift_profile,
193                    )?)
194                }
195            } else {
196                let string_array = self.convert_strings_to_numpy_f32(
197                    string_features,
198                    string_array,
199                    drift_profile.clone(),
200                )?;
201
202                if num_array.is_some() {
203                    let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
204                    let concatenated =
205                        concatenate(Axis(1), &[array.as_array(), string_array.view()])?;
206                    Ok(self.monitor.compute_drift(
207                        &features,
208                        &concatenated.view(),
209                        &drift_profile,
210                    )?)
211                } else {
212                    Ok(self.monitor.compute_drift(
213                        &features,
214                        &string_array.view(),
215                        &drift_profile,
216                    )?)
217                }
218            }
219        } else if dtype == "float64" {
220            let array = convert_array_type::<f64>(num_array.unwrap(), &dtype)?;
221            Ok(self
222                .monitor
223                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
224        } else {
225            let array = convert_array_type::<f32>(num_array.unwrap(), &dtype)?;
226            Ok(self
227                .monitor
228                .compute_drift(&num_features, &array.as_array(), &drift_profile)?)
229        }
230    }
231}