scouter_drift/psi/
monitor.rs

1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9    Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Stateless entry point for Population Stability Index (PSI) monitoring.
///
/// The struct carries no fields; it only groups the methods that build PSI
/// drift profiles and compute PSI drift values.
#[derive(Debug, Clone, Default)]
pub struct PsiMonitor {}
15
// Empty impl: `PsiMonitor` relies entirely on the methods that
// `CategoricalFeatureHelpers` (imported from `crate::utils`) already provides.
// NOTE(review): the trait body is not visible from this file - confirm what it offers.
impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19    pub fn new() -> Self {
20        PsiMonitor {}
21    }
22
23    fn compute_bin_count<F>(
24        &self,
25        array: &ArrayView<F, Ix1>,
26        lower_threshold: &f64,
27        upper_threshold: &f64,
28    ) -> usize
29    where
30        F: Float + FromPrimitive,
31        F: Into<f64>,
32    {
33        array
34            .iter()
35            .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36            .count()
37    }
38
39    fn compute_deciles<F>(&self, column_vector: &ArrayView1<F>) -> Result<[F; 9], DriftError>
40    where
41        F: Float + Default,
42        F: Into<f64>,
43    {
44        // TODO: Explore using ndarray_stats quantiles instead of manual computation
45        if column_vector.len() < 10 {
46            return Err(DriftError::NotEnoughDecileValuesError);
47        }
48
49        let sorted_column_vector = column_vector
50            .iter()
51            .sorted_by(|a, b| a.partial_cmp(b).unwrap()) // Use partial_cmp and unwrap since we assume no NaNs
52            .cloned()
53            .collect_vec();
54
55        let n = sorted_column_vector.len();
56        let mut deciles: [F; 9] = Default::default();
57
58        for i in 1..=9 {
59            let index = ((i as f32 * (n as f32 - 1.0)) / 10.0).floor() as usize;
60            deciles[i - 1] = sorted_column_vector[index];
61        }
62        let decile_vec: [F; 9] = deciles
63            .to_vec()
64            .try_into()
65            .map_err(|_| DriftError::ConvertDecileToArray)?;
66
67        Ok(decile_vec)
68    }
69
70    fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
71    where
72        F: Float + FromPrimitive + Default + Sync,
73        F: Into<f64>,
74    {
75        let vector_len = column_vector.len() as f64;
76        let mut counts: HashMap<usize, usize> = HashMap::new();
77
78        for &value in column_vector.iter() {
79            let key = Into::<f64>::into(value) as usize;
80            *counts.entry(key).or_insert(0) += 1;
81        }
82
83        counts
84            .into_par_iter()
85            .map(|(id, count)| Bin {
86                id,
87                lower_limit: None,
88                upper_limit: None,
89                proportion: (count as f64) / vector_len,
90            })
91            .collect()
92    }
93
94    fn create_numeric_bins<F>(&self, column_vector: &ArrayView1<F>) -> Result<Vec<Bin>, DriftError>
95    where
96        F: Float + FromPrimitive + Default + Sync,
97        F: Into<f64>,
98    {
99        let deciles = self.compute_deciles(column_vector)?;
100
101        let bins: Vec<Bin> = (0..=deciles.len())
102            .into_par_iter()
103            .map(|decile| {
104                let lower = if decile == 0 {
105                    F::neg_infinity()
106                } else {
107                    deciles[decile - 1]
108                };
109                let upper = if decile == deciles.len() {
110                    F::infinity()
111                } else {
112                    deciles[decile]
113                };
114                let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
115                Bin {
116                    id: decile + 1,
117                    lower_limit: Some(lower.into()),
118                    upper_limit: Some(upper.into()),
119                    proportion: (bin_count as f64) / (column_vector.len() as f64),
120                }
121            })
122            .collect();
123        Ok(bins)
124    }
125
126    fn create_bins<F>(
127        &self,
128        feature_name: &String,
129        column_vector: &ArrayView<F, Ix1>,
130        drift_config: &PsiDriftConfig,
131    ) -> Result<(Vec<Bin>, BinType), DriftError>
132    where
133        F: Float + FromPrimitive + Default + Sync,
134        F: Into<f64>,
135    {
136        match &drift_config.categorical_features {
137            Some(features) if features.contains(feature_name) => {
138                // Process as categorical
139                Ok((
140                    self.create_categorical_bins(column_vector),
141                    BinType::Category,
142                ))
143            }
144            _ => {
145                // Process as continuous
146                Ok((self.create_numeric_bins(column_vector)?, BinType::Numeric))
147            }
148        }
149    }
150
151    fn create_psi_feature_drift_profile<F>(
152        &self,
153        feature_name: String,
154        column_vector: &ArrayView<F, Ix1>,
155        drift_config: &PsiDriftConfig,
156    ) -> Result<PsiFeatureDriftProfile, DriftError>
157    where
158        F: Float + Sync + FromPrimitive + Default,
159        F: Into<f64>,
160    {
161        let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
162
163        Ok(PsiFeatureDriftProfile {
164            id: feature_name,
165            bins,
166            timestamp: chrono::Utc::now(),
167            bin_type,
168        })
169    }
170
171    pub fn create_2d_drift_profile<F>(
172        &self,
173        features: &[String],
174        array: &ArrayView2<F>,
175        drift_config: &PsiDriftConfig,
176    ) -> Result<PsiDriftProfile, DriftError>
177    where
178        F: Float + Sync + FromPrimitive + Default,
179        F: Into<f64>,
180    {
181        let mut psi_feature_drift_profiles = HashMap::new();
182
183        // Ensure that the number of features matches the number of columns in the array
184        assert_eq!(
185            features.len(),
186            array.shape()[1],
187            "Feature count must match column count."
188        );
189
190        let profile_vector = array
191            .axis_iter(Axis(1))
192            .zip(features)
193            .collect_vec()
194            .into_par_iter()
195            .map(|(column_vector, feature_name)| {
196                self.create_psi_feature_drift_profile(
197                    feature_name.to_string(),
198                    &column_vector,
199                    drift_config,
200                )
201            })
202            .collect::<Result<Vec<_>, _>>()?;
203
204        profile_vector
205            .into_iter()
206            .zip(features)
207            .for_each(|(profile, feature_name)| {
208                psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209            });
210
211        Ok(PsiDriftProfile::new(
212            psi_feature_drift_profiles,
213            drift_config.clone(),
214            None,
215        ))
216    }
217
218    fn compute_psi_proportion_pairs<F>(
219        &self,
220        column_vector: &ArrayView<F, Ix1>,
221        bin: &Bin,
222        feature_is_categorical: bool,
223    ) -> Result<(f64, f64), DriftError>
224    where
225        F: Float + FromPrimitive,
226        F: Into<f64>,
227    {
228        if feature_is_categorical {
229            let bin_count = column_vector
230                .iter()
231                .filter(|&&value| value.into() == bin.id as f64)
232                .count();
233            return Ok((
234                bin.proportion,
235                (bin_count as f64) / (column_vector.len() as f64),
236            ));
237        }
238
239        let bin_count = self.compute_bin_count(
240            column_vector,
241            &bin.lower_limit.unwrap(),
242            &bin.upper_limit.unwrap(),
243        );
244
245        Ok((
246            bin.proportion,
247            (bin_count as f64) / (column_vector.len() as f64),
248        ))
249    }
250
251    pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
252        let epsilon = 1e-10;
253        proportion_pairs
254            .iter()
255            .map(|(p, q)| {
256                let p_adj = p + epsilon;
257                let q_adj = q + epsilon;
258                (p_adj - q_adj) * (p_adj / q_adj).ln()
259            })
260            .sum()
261    }
262
263    fn compute_feature_drift<F>(
264        &self,
265        column_vector: &ArrayView<F, Ix1>,
266        feature_drift_profile: &PsiFeatureDriftProfile,
267        feature_is_categorical: bool,
268    ) -> Result<f64, DriftError>
269    where
270        F: Float + Sync + FromPrimitive,
271        F: Into<f64>,
272    {
273        let bins = &feature_drift_profile.bins;
274        let feature_proportions: Vec<(f64, f64)> = bins
275            .into_par_iter()
276            .map(|bin| {
277                self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
278            })
279            .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
280
281        Ok(PsiMonitor::compute_psi(&feature_proportions))
282    }
283
284    fn check_features<F>(
285        &self,
286        features: &[String],
287        array: &ArrayView2<F>,
288        drift_profile: &PsiDriftProfile,
289    ) -> Result<(), DriftError>
290    where
291        F: Float + Sync + FromPrimitive,
292        F: Into<f64>,
293    {
294        assert_eq!(
295            features.len(),
296            array.shape()[1],
297            "Feature count must match column count."
298        );
299
300        features
301            .iter()
302            .try_for_each(|feature_name| {
303                if !drift_profile.features.contains_key(feature_name) {
304                    // Collect all the keys from the drift profile into a comma-separated string
305                    let available_keys = drift_profile
306                        .features
307                        .keys()
308                        .cloned()
309                        .collect::<Vec<_>>()
310                        .join(", ");
311
312                    return Err(DriftError::RunTimeError(
313                        format!(
314                            "Feature mismatch, feature '{}' not found. Available features in the drift profile: {}",
315                            feature_name, available_keys
316                        ),
317                    ));
318                }
319                Ok(())
320            })
321    }
322
323    pub fn compute_drift<F>(
324        &self,
325        features: &[String],
326        array: &ArrayView2<F>,
327        drift_profile: &PsiDriftProfile,
328    ) -> Result<PsiDriftMap, DriftError>
329    where
330        F: Float + Sync + FromPrimitive,
331        F: Into<f64>,
332    {
333        self.check_features(features, array, drift_profile)?;
334
335        let drift_values: Vec<_> = array
336            .axis_iter(Axis(1))
337            .zip(features)
338            .collect_vec()
339            .into_par_iter()
340            .map(|(column_vector, feature_name)| {
341                let feature_is_categorical = drift_profile
342                    .config
343                    .categorical_features
344                    .as_ref()
345                    .is_some_and(|features| features.contains(feature_name));
346                self.compute_feature_drift(
347                    &column_vector,
348                    drift_profile.features.get(feature_name).unwrap(),
349                    feature_is_categorical,
350                )
351            })
352            .collect::<Result<Vec<f64>, DriftError>>()?;
353
354        let mut psi_drift_features = HashMap::new();
355
356        drift_values
357            .into_iter()
358            .zip(features)
359            .for_each(|(drift_value, feature_name)| {
360                psi_drift_features.insert(feature_name.clone(), drift_value);
361            });
362
363        Ok(PsiDriftMap {
364            features: psi_drift_features,
365            name: drift_profile.config.name.clone(),
366            space: drift_profile.config.space.clone(),
367            version: drift_profile.config.version.clone(),
368        })
369    }
370}
371#[cfg(test)]
372mod tests {
373    use super::*;
374    use ndarray::Array;
375    use ndarray_rand::rand_distr::Uniform;
376    use ndarray_rand::RandomExt;
377
378    #[test]
379    fn test_check_features_all_exist() {
380        let psi_monitor = PsiMonitor::default();
381
382        let array = Array::random((1030, 3), Uniform::new(0., 10.));
383
384        let features = vec![
385            "feature_1".to_string(),
386            "feature_2".to_string(),
387            "feature_3".to_string(),
388        ];
389
390        let profile = psi_monitor
391            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
392            .unwrap();
393        assert_eq!(profile.features.len(), 3);
394
395        let result = psi_monitor.check_features(&features, &array.view(), &profile);
396
397        // Assert that the result is Ok
398        assert!(result.is_ok());
399    }
400
401    #[test]
402    fn test_compute_psi_basic() {
403        let proportions = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];
404
405        let result = PsiMonitor::compute_psi(&proportions);
406
407        // Manually compute expected PSI for this case
408        let expected_psi = (0.3 - 0.2) * (0.3 / 0.2).ln()
409            + (0.4 - 0.4) * (0.4 / 0.4).ln()
410            + (0.3 - 0.4) * (0.3 / 0.4).ln();
411
412        assert!((result - expected_psi).abs() < 1e-6);
413    }
414
415    #[test]
416    fn test_compute_bin_count() {
417        let psi_monitor = PsiMonitor::default();
418
419        let data = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);
420
421        let lower_threshold = 2.0;
422        let upper_threshold = 6.0;
423
424        let result =
425            psi_monitor.compute_bin_count(&data.view(), &lower_threshold, &upper_threshold);
426
427        // Check that it counts the correct number of elements within the bin
428        // In this case, 2.5, 3.7, and 5.0 should be counted (3 elements)
429        assert_eq!(result, 3);
430    }
431
432    #[test]
433    fn test_compute_psi_proportion_pairs_categorical() {
434        let psi_monitor = PsiMonitor::default();
435
436        let cat_vector = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
437
438        let cat_zero_bin = Bin {
439            id: 0,
440            lower_limit: None,
441            upper_limit: None,
442            proportion: 0.4,
443        };
444
445        let (_, prod_proportion) = psi_monitor
446            .compute_psi_proportion_pairs(&cat_vector.view(), &cat_zero_bin, true)
447            .unwrap();
448
449        let expected_prod_proportion = 0.5;
450
451        assert!(
452            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
453            "prod_proportion was expected to be 50%"
454        );
455    }
456
457    #[test]
458    fn test_compute_psi_proportion_pairs_non_categorical() {
459        let psi_monitor = PsiMonitor::default();
460
461        let vector = Array::from_vec(vec![
462            12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
463        ]);
464
465        let bin = Bin {
466            id: 1,
467            lower_limit: Some(0.0),
468            upper_limit: Some(11.0),
469            proportion: 0.4,
470        };
471
472        let (_, prod_proportion) = psi_monitor
473            .compute_psi_proportion_pairs(&vector.view(), &bin, false)
474            .unwrap();
475
476        let expected_prod_proportion = 0.4;
477
478        assert!(
479            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
480            "prod_proportion was expected to be 40%"
481        );
482    }
483
484    #[test]
485    fn test_compute_deciles_with_unsorted_input() {
486        let psi_monitor = PsiMonitor::default();
487
488        let unsorted_vector = Array::from_vec(vec![
489            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
490            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
491        ]);
492
493        let column_view = unsorted_vector.view();
494
495        let result = psi_monitor.compute_deciles(&column_view);
496
497        let expected_deciles: [f64; 9] = [2.0, 4.0, 6.0, 10.0, 21.0, 39.0, 59.0, 71.0, 120.0];
498
499        assert_eq!(
500            result.unwrap().as_ref(),
501            expected_deciles.as_ref(),
502            "Deciles computed incorrectly for unsorted input"
503        );
504    }
505
506    #[test]
507    fn test_create_bins_non_categorical() {
508        let psi_monitor = PsiMonitor::default();
509
510        let non_categorical_data = Array::from_vec(vec![
511            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
512            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
513        ]);
514
515        let result = psi_monitor.create_numeric_bins(&ArrayView::from(&non_categorical_data));
516
517        assert!(result.is_ok());
518        let bins = result.unwrap();
519        assert_eq!(bins.len(), 10);
520    }
521
522    #[test]
523    fn test_create_bins_categorical() {
524        let psi_monitor = PsiMonitor::default();
525
526        let categorical_data = Array::from_vec(vec![
527            1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
528            1.0,
529        ]);
530
531        let bins = psi_monitor.create_categorical_bins(&ArrayView::from(&categorical_data));
532        assert_eq!(bins.len(), 3);
533    }
534
535    #[test]
536    fn test_create_2d_drift_profile() {
537        // create 2d array
538        let array = Array::random((1030, 3), Uniform::new(0., 10.));
539
540        // cast array to f32
541        let array = array.mapv(|x| x as f32);
542
543        let features = vec![
544            "feature_1".to_string(),
545            "feature_2".to_string(),
546            "feature_3".to_string(),
547        ];
548
549        let monitor = PsiMonitor::default();
550        let profile = monitor
551            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
552            .unwrap();
553
554        assert_eq!(profile.features.len(), 3);
555    }
556
557    #[test]
558    fn test_compute_drift() {
559        // create 2d array
560        let array = Array::random((1030, 3), Uniform::new(0., 10.));
561
562        // cast array to f32
563        let array = array.mapv(|x| x as f32);
564
565        let features = vec![
566            "feature_1".to_string(),
567            "feature_2".to_string(),
568            "feature_3".to_string(),
569        ];
570
571        let monitor = PsiMonitor::default();
572
573        let profile = monitor
574            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
575            .unwrap();
576
577        let drift_map = monitor
578            .compute_drift(&features, &array.view(), &profile)
579            .unwrap();
580
581        assert_eq!(drift_map.features.len(), 3);
582
583        // assert that the drift values are all 0.0
584        drift_map
585            .features
586            .values()
587            .for_each(|value| assert!(*value == 0.0));
588
589        // create new array that has drifted values
590        let mut new_array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);
591        new_array.slice_mut(s![.., 0]).mapv_inplace(|x| x + 0.01);
592
593        let new_drift_map = monitor
594            .compute_drift(&features, &new_array.view(), &profile)
595            .unwrap();
596
597        // assert that the drift values are all greater than 0.0
598        new_drift_map
599            .features
600            .values()
601            .for_each(|value| assert!(*value > 0.0));
602    }
603}