scouter_drift/psi/
monitor.rs

1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9    Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Stateless monitor used to build PSI drift profiles and to score new data
/// against an existing profile.
///
/// The struct has no fields; all behaviour lives in its methods, so a value
/// can be obtained with either `PsiMonitor::new()` or `PsiMonitor::default()`.
#[derive(Debug, Default)]
pub struct PsiMonitor {}
15
16impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19    pub fn new() -> Self {
20        PsiMonitor {}
21    }
22
23    fn compute_bin_count<F>(
24        &self,
25        array: &ArrayView<F, Ix1>,
26        lower_threshold: &f64,
27        upper_threshold: &f64,
28    ) -> usize
29    where
30        F: Float + FromPrimitive,
31        F: Into<f64>,
32    {
33        array
34            .iter()
35            .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36            .count()
37    }
38
39    fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
40    where
41        F: Float + FromPrimitive + Default + Sync,
42        F: Into<f64>,
43    {
44        let vector_len = column_vector.len() as f64;
45        let mut counts: HashMap<usize, usize> = HashMap::new();
46
47        for &value in column_vector.iter() {
48            let key = Into::<f64>::into(value) as usize;
49            *counts.entry(key).or_insert(0) += 1;
50        }
51
52        counts
53            .into_par_iter()
54            .map(|(id, count)| Bin {
55                id,
56                lower_limit: None,
57                upper_limit: None,
58                proportion: (count as f64) / vector_len,
59            })
60            .collect()
61    }
62
63    fn create_numeric_bins<F>(
64        &self,
65        column_vector: &ArrayView1<F>,
66        drift_config: &PsiDriftConfig,
67    ) -> Result<Vec<Bin>, DriftError>
68    where
69        F: Float + FromPrimitive + Default + Sync,
70        F: Into<f64>,
71    {
72        let edges = drift_config
73            .binning_strategy
74            .compute_edges(column_vector)
75            .map_err(|e| DriftError::BinningError(e.to_string()))?;
76
77        let bins: Vec<Bin> = (0..=edges.len())
78            .into_par_iter()
79            .map(|decile| {
80                let lower = if decile == 0 {
81                    F::neg_infinity()
82                } else {
83                    edges[decile - 1]
84                };
85                let upper = if decile == edges.len() {
86                    F::infinity()
87                } else {
88                    edges[decile]
89                };
90                let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
91                Bin {
92                    id: decile + 1,
93                    lower_limit: Some(lower.into()),
94                    upper_limit: Some(upper.into()),
95                    proportion: (bin_count as f64) / (column_vector.len() as f64),
96                }
97            })
98            .collect();
99        Ok(bins)
100    }
101
102    fn create_bins<F>(
103        &self,
104        feature_name: &String,
105        column_vector: &ArrayView<F, Ix1>,
106        drift_config: &PsiDriftConfig,
107    ) -> Result<(Vec<Bin>, BinType), DriftError>
108    where
109        F: Float + FromPrimitive + Default + Sync,
110        F: Into<f64>,
111    {
112        match &drift_config.categorical_features {
113            Some(features) if features.contains(feature_name) => {
114                // Process as categorical
115                Ok((
116                    self.create_categorical_bins(column_vector),
117                    BinType::Category,
118                ))
119            }
120            _ => {
121                // Process as continuous
122                Ok((
123                    self.create_numeric_bins(column_vector, drift_config)?,
124                    BinType::Numeric,
125                ))
126            }
127        }
128    }
129
130    fn create_psi_feature_drift_profile<F>(
131        &self,
132        feature_name: String,
133        column_vector: &ArrayView<F, Ix1>,
134        drift_config: &PsiDriftConfig,
135    ) -> Result<PsiFeatureDriftProfile, DriftError>
136    where
137        F: Float + Sync + FromPrimitive + Default,
138        F: Into<f64>,
139    {
140        let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
141
142        Ok(PsiFeatureDriftProfile {
143            id: feature_name,
144            bins,
145            timestamp: chrono::Utc::now(),
146            bin_type,
147        })
148    }
149
150    fn clean_column_vector<F>(column_vector: &ArrayView<F, Ix1>) -> Array<F, Ix1>
151    where
152        F: Float,
153    {
154        Array1::from(
155            column_vector
156                .iter()
157                .filter(|&&x| x.is_finite())
158                .cloned()
159                .collect::<Vec<F>>(),
160        )
161    }
162
163    pub fn create_2d_drift_profile<F>(
164        &self,
165        features: &[String],
166        array: &ArrayView2<F>,
167        drift_config: &PsiDriftConfig,
168    ) -> Result<PsiDriftProfile, DriftError>
169    where
170        F: Float + Sync + FromPrimitive + Default,
171        F: Into<f64>,
172    {
173        let mut psi_feature_drift_profiles = HashMap::new();
174
175        // Ensure that the number of features matches the number of columns in the array
176        assert_eq!(
177            features.len(),
178            array.shape()[1],
179            "Feature count must match column count."
180        );
181
182        let profile_vector = array
183            .axis_iter(Axis(1))
184            .zip(features)
185            .collect_vec()
186            .into_par_iter()
187            .map(|(column_vector, feature_name)| {
188                let clean_column_vector = Self::clean_column_vector(&column_vector);
189
190                if clean_column_vector.is_empty() {
191                    return Err(DriftError::EmptyArrayError(
192                        format!("PSI drift profile creation failure, unable to create psi feature drift profile for {feature_name}")
193                    ));
194                }
195
196                self.create_psi_feature_drift_profile(
197                    feature_name.to_string(),
198                    &clean_column_vector.view(),
199                    drift_config,
200                )
201            })
202            .collect::<Result<Vec<_>, _>>()?;
203
204        profile_vector
205            .into_iter()
206            .zip(features)
207            .for_each(|(profile, feature_name)| {
208                psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209            });
210
211        Ok(PsiDriftProfile::new(
212            psi_feature_drift_profiles,
213            drift_config.clone(),
214        ))
215    }
216
217    fn compute_psi_proportion_pairs<F>(
218        &self,
219        column_vector: &ArrayView<F, Ix1>,
220        bin: &Bin,
221        feature_is_categorical: bool,
222    ) -> Result<(f64, f64), DriftError>
223    where
224        F: Float + FromPrimitive,
225        F: Into<f64>,
226    {
227        if feature_is_categorical {
228            let bin_count = column_vector
229                .iter()
230                .filter(|&&value| value.into() == bin.id as f64)
231                .count();
232            return Ok((
233                bin.proportion,
234                (bin_count as f64) / (column_vector.len() as f64),
235            ));
236        }
237
238        let bin_count = self.compute_bin_count(
239            column_vector,
240            &bin.lower_limit.unwrap(),
241            &bin.upper_limit.unwrap(),
242        );
243
244        Ok((
245            bin.proportion,
246            (bin_count as f64) / (column_vector.len() as f64),
247        ))
248    }
249
250    pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
251        let epsilon = 1e-10;
252        proportion_pairs
253            .iter()
254            .map(|(p, q)| {
255                let p_adj = p + epsilon;
256                let q_adj = q + epsilon;
257                (p_adj - q_adj) * (p_adj / q_adj).ln()
258            })
259            .sum()
260    }
261
262    fn compute_feature_drift<F>(
263        &self,
264        column_vector: &ArrayView<F, Ix1>,
265        feature_drift_profile: &PsiFeatureDriftProfile,
266        feature_is_categorical: bool,
267    ) -> Result<f64, DriftError>
268    where
269        F: Float + Sync + FromPrimitive,
270        F: Into<f64>,
271    {
272        let bins = &feature_drift_profile.bins;
273        let feature_proportions: Vec<(f64, f64)> = bins
274            .into_par_iter()
275            .map(|bin| {
276                self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
277            })
278            .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
279
280        Ok(PsiMonitor::compute_psi(&feature_proportions))
281    }
282
283    fn check_features<F>(
284        &self,
285        features: &[String],
286        array: &ArrayView2<F>,
287        drift_profile: &PsiDriftProfile,
288    ) -> Result<(), DriftError>
289    where
290        F: Float + Sync + FromPrimitive,
291        F: Into<f64>,
292    {
293        assert_eq!(
294            features.len(),
295            array.shape()[1],
296            "Feature count must match column count."
297        );
298
299        features
300            .iter()
301            .try_for_each(|feature_name| {
302                if !drift_profile.features.contains_key(feature_name) {
303                    // Collect all the keys from the drift profile into a comma-separated string
304                    let available_keys = drift_profile
305                        .features
306                        .keys()
307                        .cloned()
308                        .collect::<Vec<_>>()
309                        .join(", ");
310
311                    return Err(DriftError::RunTimeError(
312                        format!(
313                            "Feature mismatch, feature '{feature_name}' not found. Available features in the drift profile: {available_keys}"
314                        ),
315                    ));
316                }
317                Ok(())
318            })
319    }
320
321    pub fn compute_drift<F>(
322        &self,
323        features: &[String],
324        array: &ArrayView2<F>,
325        drift_profile: &PsiDriftProfile,
326    ) -> Result<PsiDriftMap, DriftError>
327    where
328        F: Float + Sync + FromPrimitive,
329        F: Into<f64>,
330    {
331        self.check_features(features, array, drift_profile)?;
332
333        let drift_values: Vec<_> = array
334            .axis_iter(Axis(1))
335            .zip(features)
336            .collect_vec()
337            .into_par_iter()
338            .map(|(column_vector, feature_name)| {
339                let feature_is_categorical = drift_profile
340                    .config
341                    .categorical_features
342                    .as_ref()
343                    .is_some_and(|features| features.contains(feature_name));
344                self.compute_feature_drift(
345                    &column_vector,
346                    drift_profile.features.get(feature_name).unwrap(),
347                    feature_is_categorical,
348                )
349            })
350            .collect::<Result<Vec<f64>, DriftError>>()?;
351
352        let mut psi_drift_features = HashMap::new();
353
354        drift_values
355            .into_iter()
356            .zip(features)
357            .for_each(|(drift_value, feature_name)| {
358                psi_drift_features.insert(feature_name.clone(), drift_value);
359            });
360
361        Ok(PsiDriftMap {
362            features: psi_drift_features,
363            name: drift_profile.config.name.clone(),
364            space: drift_profile.config.space.clone(),
365            version: drift_profile.config.version.clone(),
366        })
367    }
368}
369#[cfg(test)]
370mod tests {
371    use super::*;
372    use ndarray::Array;
373    use ndarray_rand::rand_distr::Uniform;
374    use ndarray_rand::RandomExt;
375
376    #[test]
377    fn test_check_features_all_exist() {
378        let psi_monitor = PsiMonitor::default();
379
380        let array = Array::random((1030, 3), Uniform::new(0., 10.));
381
382        let features = vec![
383            "feature_1".to_string(),
384            "feature_2".to_string(),
385            "feature_3".to_string(),
386        ];
387
388        let profile = psi_monitor
389            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
390            .unwrap();
391        assert_eq!(profile.features.len(), 3);
392
393        let result = psi_monitor.check_features(&features, &array.view(), &profile);
394
395        // Assert that the result is Ok
396        assert!(result.is_ok());
397    }
398
399    #[test]
400    fn test_compute_psi_basic() {
401        let proportions = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];
402
403        let result = PsiMonitor::compute_psi(&proportions);
404
405        // Manually compute expected PSI for this case
406        let expected_psi = (0.3 - 0.2) * (0.3 / 0.2).ln()
407            + (0.4 - 0.4) * (0.4 / 0.4).ln()
408            + (0.3 - 0.4) * (0.3 / 0.4).ln();
409
410        assert!((result - expected_psi).abs() < 1e-6);
411    }
412
413    #[test]
414    fn test_compute_bin_count() {
415        let psi_monitor = PsiMonitor::default();
416
417        let data = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);
418
419        let lower_threshold = 2.0;
420        let upper_threshold = 6.0;
421
422        let result =
423            psi_monitor.compute_bin_count(&data.view(), &lower_threshold, &upper_threshold);
424
425        // Check that it counts the correct number of elements within the bin
426        // In this case, 2.5, 3.7, and 5.0 should be counted (3 elements)
427        assert_eq!(result, 3);
428    }
429
430    #[test]
431    fn test_compute_psi_proportion_pairs_categorical() {
432        let psi_monitor = PsiMonitor::default();
433
434        let cat_vector = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
435
436        let cat_zero_bin = Bin {
437            id: 0,
438            lower_limit: None,
439            upper_limit: None,
440            proportion: 0.4,
441        };
442
443        let (_, prod_proportion) = psi_monitor
444            .compute_psi_proportion_pairs(&cat_vector.view(), &cat_zero_bin, true)
445            .unwrap();
446
447        let expected_prod_proportion = 0.5;
448
449        assert!(
450            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
451            "prod_proportion was expected to be 50%"
452        );
453    }
454
455    #[test]
456    fn test_compute_psi_proportion_pairs_non_categorical() {
457        let psi_monitor = PsiMonitor::default();
458
459        let vector = Array::from_vec(vec![
460            12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
461        ]);
462
463        let bin = Bin {
464            id: 1,
465            lower_limit: Some(0.0),
466            upper_limit: Some(11.0),
467            proportion: 0.4,
468        };
469
470        let (_, prod_proportion) = psi_monitor
471            .compute_psi_proportion_pairs(&vector.view(), &bin, false)
472            .unwrap();
473
474        let expected_prod_proportion = 0.4;
475
476        assert!(
477            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
478            "prod_proportion was expected to be 40%"
479        );
480    }
481
482    #[test]
483    fn test_create_bins_non_categorical() {
484        let psi_monitor = PsiMonitor::default();
485
486        let non_categorical_data = Array::from_vec(vec![
487            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
488            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
489        ]);
490
491        let result = psi_monitor.create_numeric_bins(
492            &ArrayView::from(&non_categorical_data),
493            &PsiDriftConfig::default(),
494        );
495
496        assert!(result.is_ok());
497        let bins = result.unwrap();
498        assert_eq!(bins.len(), 10);
499    }
500
501    #[test]
502    fn test_create_bins_categorical() {
503        let psi_monitor = PsiMonitor::default();
504
505        let categorical_data = Array::from_vec(vec![
506            1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
507            1.0,
508        ]);
509
510        let bins = psi_monitor.create_categorical_bins(&ArrayView::from(&categorical_data));
511        assert_eq!(bins.len(), 3);
512    }
513
514    #[test]
515    fn test_create_2d_drift_profile() {
516        // create 2d array
517        let array = Array::random((1030, 3), Uniform::new(0., 10.));
518
519        // cast array to f32
520        let array = array.mapv(|x| x as f32);
521
522        let features = vec![
523            "feature_1".to_string(),
524            "feature_2".to_string(),
525            "feature_3".to_string(),
526        ];
527
528        let monitor = PsiMonitor::default();
529        let profile = monitor
530            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
531            .unwrap();
532
533        assert_eq!(profile.features.len(), 3);
534    }
535
536    #[test]
537    fn test_compute_drift() {
538        // create 2d array
539        let array = Array::random((1030, 3), Uniform::new(0., 10.));
540
541        // cast array to f32
542        let array = array.mapv(|x| x as f32);
543
544        let features = vec![
545            "feature_1".to_string(),
546            "feature_2".to_string(),
547            "feature_3".to_string(),
548        ];
549
550        let monitor = PsiMonitor::default();
551
552        let profile = monitor
553            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
554            .unwrap();
555
556        let drift_map = monitor
557            .compute_drift(&features, &array.view(), &profile)
558            .unwrap();
559
560        assert_eq!(drift_map.features.len(), 3);
561
562        // assert that the drift values are all 0.0
563        drift_map
564            .features
565            .values()
566            .for_each(|value| assert!(*value == 0.0));
567
568        // create new array that has drifted values
569        let mut new_array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);
570        new_array.slice_mut(s![.., 0]).mapv_inplace(|x| x + 0.01);
571
572        let new_drift_map = monitor
573            .compute_drift(&features, &new_array.view(), &profile)
574            .unwrap();
575
576        // assert that the drift values are all greater than 0.0
577        new_drift_map
578            .features
579            .values()
580            .for_each(|value| assert!(*value > 0.0));
581    }
582
583    #[test]
584    fn test_empty_array() {
585        let features = vec!["feature_1".to_string()];
586
587        let data = Array::<f64, _>::zeros((0, 1));
588
589        let monitor = PsiMonitor::default();
590
591        let profile =
592            monitor.create_2d_drift_profile(&features, &data.view(), &PsiDriftConfig::default());
593
594        assert!(profile.is_err());
595        match profile.unwrap_err() {
596            DriftError::EmptyArrayError(msg) => {
597                assert_eq!(msg, "PSI drift profile creation failure, unable to create psi feature drift profile for feature_1");
598            }
599            _ => panic!("Expected EmptyArrayError"),
600        }
601    }
602}