scouter_drift/psi/
monitor.rs

1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9    Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Stateless entry point for Population Stability Index (PSI) drift monitoring.
///
/// The type has no fields; it only groups the methods that build PSI drift
/// profiles and score new data against them, so instances are free to create,
/// clone and share.
#[derive(Debug, Default, Clone)]
pub struct PsiMonitor {}
15
// Adopts the `CategoricalFeatureHelpers` trait as-is: the empty body means only the
// trait's provided (default) methods are used and nothing is overridden.
16impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19    pub fn new() -> Self {
20        PsiMonitor {}
21    }
22
23    fn compute_bin_count<F>(
24        &self,
25        array: &ArrayView<F, Ix1>,
26        lower_threshold: &f64,
27        upper_threshold: &f64,
28    ) -> usize
29    where
30        F: Float + FromPrimitive,
31        F: Into<f64>,
32    {
33        array
34            .iter()
35            .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36            .count()
37    }
38
39    fn compute_deciles<F>(&self, column_vector: &ArrayView1<F>) -> Result<[F; 9], DriftError>
40    where
41        F: Float + Default,
42        F: Into<f64>,
43    {
44        // TODO: Explore using ndarray_stats quantiles instead of manual computation
45        if column_vector.len() < 10 {
46            return Err(DriftError::NotEnoughDecileValuesError);
47        }
48
49        let sorted_column_vector = column_vector
50            .iter()
51            .sorted_by(|a, b| a.partial_cmp(b).unwrap()) // Use partial_cmp and unwrap since we assume no NaNs
52            .cloned()
53            .collect_vec();
54
55        let n = sorted_column_vector.len();
56        let mut deciles: [F; 9] = Default::default();
57
58        for i in 1..=9 {
59            let index = ((i as f32 * (n as f32 - 1.0)) / 10.0).floor() as usize;
60            deciles[i - 1] = sorted_column_vector[index];
61        }
62        let decile_vec: [F; 9] = deciles
63            .to_vec()
64            .try_into()
65            .map_err(|_| DriftError::ConvertDecileToArray)?;
66
67        Ok(decile_vec)
68    }
69
70    fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
71    where
72        F: Float + FromPrimitive + Default + Sync,
73        F: Into<f64>,
74    {
75        let vector_len = column_vector.len() as f64;
76        let mut counts: HashMap<usize, usize> = HashMap::new();
77
78        for &value in column_vector.iter() {
79            let key = Into::<f64>::into(value) as usize;
80            *counts.entry(key).or_insert(0) += 1;
81        }
82
83        counts
84            .into_par_iter()
85            .map(|(id, count)| Bin {
86                id,
87                lower_limit: None,
88                upper_limit: None,
89                proportion: (count as f64) / vector_len,
90            })
91            .collect()
92    }
93
94    fn create_numeric_bins<F>(&self, column_vector: &ArrayView1<F>) -> Result<Vec<Bin>, DriftError>
95    where
96        F: Float + FromPrimitive + Default + Sync,
97        F: Into<f64>,
98    {
99        let deciles = self.compute_deciles(column_vector)?;
100
101        let bins: Vec<Bin> = (0..=deciles.len())
102            .into_par_iter()
103            .map(|decile| {
104                let lower = if decile == 0 {
105                    F::neg_infinity()
106                } else {
107                    deciles[decile - 1]
108                };
109                let upper = if decile == deciles.len() {
110                    F::infinity()
111                } else {
112                    deciles[decile]
113                };
114                let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
115                Bin {
116                    id: decile + 1,
117                    lower_limit: Some(lower.into()),
118                    upper_limit: Some(upper.into()),
119                    proportion: (bin_count as f64) / (column_vector.len() as f64),
120                }
121            })
122            .collect();
123        Ok(bins)
124    }
125
126    fn create_bins<F>(
127        &self,
128        feature_name: &String,
129        column_vector: &ArrayView<F, Ix1>,
130        drift_config: &PsiDriftConfig,
131    ) -> Result<(Vec<Bin>, BinType), DriftError>
132    where
133        F: Float + FromPrimitive + Default + Sync,
134        F: Into<f64>,
135    {
136        match &drift_config.categorical_features {
137            Some(features) if features.contains(feature_name) => {
138                // Process as categorical
139                Ok((
140                    self.create_categorical_bins(column_vector),
141                    BinType::Category,
142                ))
143            }
144            _ => {
145                // Process as continuous
146                Ok((self.create_numeric_bins(column_vector)?, BinType::Numeric))
147            }
148        }
149    }
150
151    fn create_psi_feature_drift_profile<F>(
152        &self,
153        feature_name: String,
154        column_vector: &ArrayView<F, Ix1>,
155        drift_config: &PsiDriftConfig,
156    ) -> Result<PsiFeatureDriftProfile, DriftError>
157    where
158        F: Float + Sync + FromPrimitive + Default,
159        F: Into<f64>,
160    {
161        let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
162
163        Ok(PsiFeatureDriftProfile {
164            id: feature_name,
165            bins,
166            timestamp: chrono::Utc::now(),
167            bin_type,
168        })
169    }
170
171    pub fn create_2d_drift_profile<F>(
172        &self,
173        features: &[String],
174        array: &ArrayView2<F>,
175        drift_config: &PsiDriftConfig,
176    ) -> Result<PsiDriftProfile, DriftError>
177    where
178        F: Float + Sync + FromPrimitive + Default,
179        F: Into<f64>,
180    {
181        let mut psi_feature_drift_profiles = HashMap::new();
182
183        // Ensure that the number of features matches the number of columns in the array
184        assert_eq!(
185            features.len(),
186            array.shape()[1],
187            "Feature count must match column count."
188        );
189
190        let profile_vector = array
191            .axis_iter(Axis(1))
192            .zip(features)
193            .collect_vec()
194            .into_par_iter()
195            .map(|(column_vector, feature_name)| {
196                self.create_psi_feature_drift_profile(
197                    feature_name.to_string(),
198                    &column_vector,
199                    drift_config,
200                )
201            })
202            .collect::<Result<Vec<_>, _>>()?;
203
204        profile_vector
205            .into_iter()
206            .zip(features)
207            .for_each(|(profile, feature_name)| {
208                psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209            });
210
211        Ok(PsiDriftProfile::new(
212            psi_feature_drift_profiles,
213            drift_config.clone(),
214            None,
215        ))
216    }
217
218    fn compute_psi_proportion_pairs<F>(
219        &self,
220        column_vector: &ArrayView<F, Ix1>,
221        bin: &Bin,
222        feature_is_categorical: bool,
223    ) -> Result<(f64, f64), DriftError>
224    where
225        F: Float + FromPrimitive,
226        F: Into<f64>,
227    {
228        if feature_is_categorical {
229            let bin_count = column_vector
230                .iter()
231                .filter(|&&value| value.into() == bin.id as f64)
232                .count();
233            return Ok((
234                bin.proportion,
235                (bin_count as f64) / (column_vector.len() as f64),
236            ));
237        }
238
239        let bin_count = self.compute_bin_count(
240            column_vector,
241            &bin.lower_limit.unwrap(),
242            &bin.upper_limit.unwrap(),
243        );
244
245        Ok((
246            bin.proportion,
247            (bin_count as f64) / (column_vector.len() as f64),
248        ))
249    }
250
251    pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
252        let epsilon = 1e-10;
253        proportion_pairs
254            .iter()
255            .map(|(p, q)| {
256                let p_adj = p + epsilon;
257                let q_adj = q + epsilon;
258                (p_adj - q_adj) * (p_adj / q_adj).ln()
259            })
260            .sum()
261    }
262
263    fn compute_feature_drift<F>(
264        &self,
265        column_vector: &ArrayView<F, Ix1>,
266        feature_drift_profile: &PsiFeatureDriftProfile,
267        feature_is_categorical: bool,
268    ) -> Result<f64, DriftError>
269    where
270        F: Float + Sync + FromPrimitive,
271        F: Into<f64>,
272    {
273        let bins = &feature_drift_profile.bins;
274        let feature_proportions: Vec<(f64, f64)> = bins
275            .into_par_iter()
276            .map(|bin| {
277                self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
278            })
279            .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
280
281        Ok(PsiMonitor::compute_psi(&feature_proportions))
282    }
283
284    fn check_features<F>(
285        &self,
286        features: &[String],
287        array: &ArrayView2<F>,
288        drift_profile: &PsiDriftProfile,
289    ) -> Result<(), DriftError>
290    where
291        F: Float + Sync + FromPrimitive,
292        F: Into<f64>,
293    {
294        assert_eq!(
295            features.len(),
296            array.shape()[1],
297            "Feature count must match column count."
298        );
299
300        features
301            .iter()
302            .try_for_each(|feature_name| {
303                if !drift_profile.features.contains_key(feature_name) {
304                    // Collect all the keys from the drift profile into a comma-separated string
305                    let available_keys = drift_profile
306                        .features
307                        .keys()
308                        .cloned()
309                        .collect::<Vec<_>>()
310                        .join(", ");
311
312                    return Err(DriftError::RunTimeError(
313                        format!(
314                            "Feature mismatch, feature '{feature_name}' not found. Available features in the drift profile: {available_keys}"
315                        ),
316                    ));
317                }
318                Ok(())
319            })
320    }
321
322    pub fn compute_drift<F>(
323        &self,
324        features: &[String],
325        array: &ArrayView2<F>,
326        drift_profile: &PsiDriftProfile,
327    ) -> Result<PsiDriftMap, DriftError>
328    where
329        F: Float + Sync + FromPrimitive,
330        F: Into<f64>,
331    {
332        self.check_features(features, array, drift_profile)?;
333
334        let drift_values: Vec<_> = array
335            .axis_iter(Axis(1))
336            .zip(features)
337            .collect_vec()
338            .into_par_iter()
339            .map(|(column_vector, feature_name)| {
340                let feature_is_categorical = drift_profile
341                    .config
342                    .categorical_features
343                    .as_ref()
344                    .is_some_and(|features| features.contains(feature_name));
345                self.compute_feature_drift(
346                    &column_vector,
347                    drift_profile.features.get(feature_name).unwrap(),
348                    feature_is_categorical,
349                )
350            })
351            .collect::<Result<Vec<f64>, DriftError>>()?;
352
353        let mut psi_drift_features = HashMap::new();
354
355        drift_values
356            .into_iter()
357            .zip(features)
358            .for_each(|(drift_value, feature_name)| {
359                psi_drift_features.insert(feature_name.clone(), drift_value);
360            });
361
362        Ok(PsiDriftMap {
363            features: psi_drift_features,
364            name: drift_profile.config.name.clone(),
365            space: drift_profile.config.space.clone(),
366            version: drift_profile.config.version.clone(),
367        })
368    }
369}
370#[cfg(test)]
371mod tests {
372    use super::*;
373    use ndarray::Array;
374    use ndarray_rand::rand_distr::Uniform;
375    use ndarray_rand::RandomExt;
376
377    #[test]
378    fn test_check_features_all_exist() {
379        let psi_monitor = PsiMonitor::default();
380
381        let array = Array::random((1030, 3), Uniform::new(0., 10.));
382
383        let features = vec![
384            "feature_1".to_string(),
385            "feature_2".to_string(),
386            "feature_3".to_string(),
387        ];
388
389        let profile = psi_monitor
390            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
391            .unwrap();
392        assert_eq!(profile.features.len(), 3);
393
394        let result = psi_monitor.check_features(&features, &array.view(), &profile);
395
396        // Assert that the result is Ok
397        assert!(result.is_ok());
398    }
399
400    #[test]
401    fn test_compute_psi_basic() {
402        let proportions = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];
403
404        let result = PsiMonitor::compute_psi(&proportions);
405
406        // Manually compute expected PSI for this case
407        let expected_psi = (0.3 - 0.2) * (0.3 / 0.2).ln()
408            + (0.4 - 0.4) * (0.4 / 0.4).ln()
409            + (0.3 - 0.4) * (0.3 / 0.4).ln();
410
411        assert!((result - expected_psi).abs() < 1e-6);
412    }
413
414    #[test]
415    fn test_compute_bin_count() {
416        let psi_monitor = PsiMonitor::default();
417
418        let data = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);
419
420        let lower_threshold = 2.0;
421        let upper_threshold = 6.0;
422
423        let result =
424            psi_monitor.compute_bin_count(&data.view(), &lower_threshold, &upper_threshold);
425
426        // Check that it counts the correct number of elements within the bin
427        // In this case, 2.5, 3.7, and 5.0 should be counted (3 elements)
428        assert_eq!(result, 3);
429    }
430
431    #[test]
432    fn test_compute_psi_proportion_pairs_categorical() {
433        let psi_monitor = PsiMonitor::default();
434
435        let cat_vector = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
436
437        let cat_zero_bin = Bin {
438            id: 0,
439            lower_limit: None,
440            upper_limit: None,
441            proportion: 0.4,
442        };
443
444        let (_, prod_proportion) = psi_monitor
445            .compute_psi_proportion_pairs(&cat_vector.view(), &cat_zero_bin, true)
446            .unwrap();
447
448        let expected_prod_proportion = 0.5;
449
450        assert!(
451            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
452            "prod_proportion was expected to be 50%"
453        );
454    }
455
456    #[test]
457    fn test_compute_psi_proportion_pairs_non_categorical() {
458        let psi_monitor = PsiMonitor::default();
459
460        let vector = Array::from_vec(vec![
461            12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
462        ]);
463
464        let bin = Bin {
465            id: 1,
466            lower_limit: Some(0.0),
467            upper_limit: Some(11.0),
468            proportion: 0.4,
469        };
470
471        let (_, prod_proportion) = psi_monitor
472            .compute_psi_proportion_pairs(&vector.view(), &bin, false)
473            .unwrap();
474
475        let expected_prod_proportion = 0.4;
476
477        assert!(
478            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
479            "prod_proportion was expected to be 40%"
480        );
481    }
482
483    #[test]
484    fn test_compute_deciles_with_unsorted_input() {
485        let psi_monitor = PsiMonitor::default();
486
487        let unsorted_vector = Array::from_vec(vec![
488            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
489            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
490        ]);
491
492        let column_view = unsorted_vector.view();
493
494        let result = psi_monitor.compute_deciles(&column_view);
495
496        let expected_deciles: [f64; 9] = [2.0, 4.0, 6.0, 10.0, 21.0, 39.0, 59.0, 71.0, 120.0];
497
498        assert_eq!(
499            result.unwrap().as_ref(),
500            expected_deciles.as_ref(),
501            "Deciles computed incorrectly for unsorted input"
502        );
503    }
504
505    #[test]
506    fn test_create_bins_non_categorical() {
507        let psi_monitor = PsiMonitor::default();
508
509        let non_categorical_data = Array::from_vec(vec![
510            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
511            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
512        ]);
513
514        let result = psi_monitor.create_numeric_bins(&ArrayView::from(&non_categorical_data));
515
516        assert!(result.is_ok());
517        let bins = result.unwrap();
518        assert_eq!(bins.len(), 10);
519    }
520
521    #[test]
522    fn test_create_bins_categorical() {
523        let psi_monitor = PsiMonitor::default();
524
525        let categorical_data = Array::from_vec(vec![
526            1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
527            1.0,
528        ]);
529
530        let bins = psi_monitor.create_categorical_bins(&ArrayView::from(&categorical_data));
531        assert_eq!(bins.len(), 3);
532    }
533
534    #[test]
535    fn test_create_2d_drift_profile() {
536        // create 2d array
537        let array = Array::random((1030, 3), Uniform::new(0., 10.));
538
539        // cast array to f32
540        let array = array.mapv(|x| x as f32);
541
542        let features = vec![
543            "feature_1".to_string(),
544            "feature_2".to_string(),
545            "feature_3".to_string(),
546        ];
547
548        let monitor = PsiMonitor::default();
549        let profile = monitor
550            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
551            .unwrap();
552
553        assert_eq!(profile.features.len(), 3);
554    }
555
556    #[test]
557    fn test_compute_drift() {
558        // create 2d array
559        let array = Array::random((1030, 3), Uniform::new(0., 10.));
560
561        // cast array to f32
562        let array = array.mapv(|x| x as f32);
563
564        let features = vec![
565            "feature_1".to_string(),
566            "feature_2".to_string(),
567            "feature_3".to_string(),
568        ];
569
570        let monitor = PsiMonitor::default();
571
572        let profile = monitor
573            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
574            .unwrap();
575
576        let drift_map = monitor
577            .compute_drift(&features, &array.view(), &profile)
578            .unwrap();
579
580        assert_eq!(drift_map.features.len(), 3);
581
582        // assert that the drift values are all 0.0
583        drift_map
584            .features
585            .values()
586            .for_each(|value| assert!(*value == 0.0));
587
588        // create new array that has drifted values
589        let mut new_array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);
590        new_array.slice_mut(s![.., 0]).mapv_inplace(|x| x + 0.01);
591
592        let new_drift_map = monitor
593            .compute_drift(&features, &new_array.view(), &profile)
594            .unwrap();
595
596        // assert that the drift values are all greater than 0.0
597        new_drift_map
598            .features
599            .values()
600            .for_each(|value| assert!(*value > 0.0));
601    }
602}