1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9 Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Builds PSI (population stability index) drift profiles and computes drift
/// against them.
///
/// The monitor holds no state; every input is passed to its methods.
#[derive(Debug, Default)]
pub struct PsiMonitor {}
15
16impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19 pub fn new() -> Self {
20 PsiMonitor {}
21 }
22
23 fn compute_bin_count<F>(
24 &self,
25 array: &ArrayView<F, Ix1>,
26 lower_threshold: &f64,
27 upper_threshold: &f64,
28 ) -> usize
29 where
30 F: Float + FromPrimitive,
31 F: Into<f64>,
32 {
33 array
34 .iter()
35 .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36 .count()
37 }
38
39 fn compute_deciles<F>(&self, column_vector: &ArrayView1<F>) -> Result<[F; 9], DriftError>
40 where
41 F: Float + Default,
42 F: Into<f64>,
43 {
44 if column_vector.len() < 10 {
46 return Err(DriftError::NotEnoughDecileValuesError);
47 }
48
49 let sorted_column_vector = column_vector
50 .iter()
51 .sorted_by(|a, b| a.partial_cmp(b).unwrap()) .cloned()
53 .collect_vec();
54
55 let n = sorted_column_vector.len();
56 let mut deciles: [F; 9] = Default::default();
57
58 for i in 1..=9 {
59 let index = ((i as f32 * (n as f32 - 1.0)) / 10.0).floor() as usize;
60 deciles[i - 1] = sorted_column_vector[index];
61 }
62 let decile_vec: [F; 9] = deciles
63 .to_vec()
64 .try_into()
65 .map_err(|_| DriftError::ConvertDecileToArray)?;
66
67 Ok(decile_vec)
68 }
69
70 fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
71 where
72 F: Float + FromPrimitive + Default + Sync,
73 F: Into<f64>,
74 {
75 let vector_len = column_vector.len() as f64;
76 let mut counts: HashMap<usize, usize> = HashMap::new();
77
78 for &value in column_vector.iter() {
79 let key = Into::<f64>::into(value) as usize;
80 *counts.entry(key).or_insert(0) += 1;
81 }
82
83 counts
84 .into_par_iter()
85 .map(|(id, count)| Bin {
86 id,
87 lower_limit: None,
88 upper_limit: None,
89 proportion: (count as f64) / vector_len,
90 })
91 .collect()
92 }
93
94 fn create_numeric_bins<F>(&self, column_vector: &ArrayView1<F>) -> Result<Vec<Bin>, DriftError>
95 where
96 F: Float + FromPrimitive + Default + Sync,
97 F: Into<f64>,
98 {
99 let deciles = self.compute_deciles(column_vector)?;
100
101 let bins: Vec<Bin> = (0..=deciles.len())
102 .into_par_iter()
103 .map(|decile| {
104 let lower = if decile == 0 {
105 F::neg_infinity()
106 } else {
107 deciles[decile - 1]
108 };
109 let upper = if decile == deciles.len() {
110 F::infinity()
111 } else {
112 deciles[decile]
113 };
114 let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
115 Bin {
116 id: decile + 1,
117 lower_limit: Some(lower.into()),
118 upper_limit: Some(upper.into()),
119 proportion: (bin_count as f64) / (column_vector.len() as f64),
120 }
121 })
122 .collect();
123 Ok(bins)
124 }
125
126 fn create_bins<F>(
127 &self,
128 feature_name: &String,
129 column_vector: &ArrayView<F, Ix1>,
130 drift_config: &PsiDriftConfig,
131 ) -> Result<(Vec<Bin>, BinType), DriftError>
132 where
133 F: Float + FromPrimitive + Default + Sync,
134 F: Into<f64>,
135 {
136 match &drift_config.categorical_features {
137 Some(features) if features.contains(feature_name) => {
138 Ok((
140 self.create_categorical_bins(column_vector),
141 BinType::Category,
142 ))
143 }
144 _ => {
145 Ok((self.create_numeric_bins(column_vector)?, BinType::Numeric))
147 }
148 }
149 }
150
151 fn create_psi_feature_drift_profile<F>(
152 &self,
153 feature_name: String,
154 column_vector: &ArrayView<F, Ix1>,
155 drift_config: &PsiDriftConfig,
156 ) -> Result<PsiFeatureDriftProfile, DriftError>
157 where
158 F: Float + Sync + FromPrimitive + Default,
159 F: Into<f64>,
160 {
161 let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
162
163 Ok(PsiFeatureDriftProfile {
164 id: feature_name,
165 bins,
166 timestamp: chrono::Utc::now(),
167 bin_type,
168 })
169 }
170
171 pub fn create_2d_drift_profile<F>(
172 &self,
173 features: &[String],
174 array: &ArrayView2<F>,
175 drift_config: &PsiDriftConfig,
176 ) -> Result<PsiDriftProfile, DriftError>
177 where
178 F: Float + Sync + FromPrimitive + Default,
179 F: Into<f64>,
180 {
181 let mut psi_feature_drift_profiles = HashMap::new();
182
183 assert_eq!(
185 features.len(),
186 array.shape()[1],
187 "Feature count must match column count."
188 );
189
190 let profile_vector = array
191 .axis_iter(Axis(1))
192 .zip(features)
193 .collect_vec()
194 .into_par_iter()
195 .map(|(column_vector, feature_name)| {
196 self.create_psi_feature_drift_profile(
197 feature_name.to_string(),
198 &column_vector,
199 drift_config,
200 )
201 })
202 .collect::<Result<Vec<_>, _>>()?;
203
204 profile_vector
205 .into_iter()
206 .zip(features)
207 .for_each(|(profile, feature_name)| {
208 psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209 });
210
211 Ok(PsiDriftProfile::new(
212 psi_feature_drift_profiles,
213 drift_config.clone(),
214 None,
215 ))
216 }
217
218 fn compute_psi_proportion_pairs<F>(
219 &self,
220 column_vector: &ArrayView<F, Ix1>,
221 bin: &Bin,
222 feature_is_categorical: bool,
223 ) -> Result<(f64, f64), DriftError>
224 where
225 F: Float + FromPrimitive,
226 F: Into<f64>,
227 {
228 if feature_is_categorical {
229 let bin_count = column_vector
230 .iter()
231 .filter(|&&value| value.into() == bin.id as f64)
232 .count();
233 return Ok((
234 bin.proportion,
235 (bin_count as f64) / (column_vector.len() as f64),
236 ));
237 }
238
239 let bin_count = self.compute_bin_count(
240 column_vector,
241 &bin.lower_limit.unwrap(),
242 &bin.upper_limit.unwrap(),
243 );
244
245 Ok((
246 bin.proportion,
247 (bin_count as f64) / (column_vector.len() as f64),
248 ))
249 }
250
251 pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
252 let epsilon = 1e-10;
253 proportion_pairs
254 .iter()
255 .map(|(p, q)| {
256 let p_adj = p + epsilon;
257 let q_adj = q + epsilon;
258 (p_adj - q_adj) * (p_adj / q_adj).ln()
259 })
260 .sum()
261 }
262
263 fn compute_feature_drift<F>(
264 &self,
265 column_vector: &ArrayView<F, Ix1>,
266 feature_drift_profile: &PsiFeatureDriftProfile,
267 feature_is_categorical: bool,
268 ) -> Result<f64, DriftError>
269 where
270 F: Float + Sync + FromPrimitive,
271 F: Into<f64>,
272 {
273 let bins = &feature_drift_profile.bins;
274 let feature_proportions: Vec<(f64, f64)> = bins
275 .into_par_iter()
276 .map(|bin| {
277 self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
278 })
279 .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
280
281 Ok(PsiMonitor::compute_psi(&feature_proportions))
282 }
283
284 fn check_features<F>(
285 &self,
286 features: &[String],
287 array: &ArrayView2<F>,
288 drift_profile: &PsiDriftProfile,
289 ) -> Result<(), DriftError>
290 where
291 F: Float + Sync + FromPrimitive,
292 F: Into<f64>,
293 {
294 assert_eq!(
295 features.len(),
296 array.shape()[1],
297 "Feature count must match column count."
298 );
299
300 features
301 .iter()
302 .try_for_each(|feature_name| {
303 if !drift_profile.features.contains_key(feature_name) {
304 let available_keys = drift_profile
306 .features
307 .keys()
308 .cloned()
309 .collect::<Vec<_>>()
310 .join(", ");
311
312 return Err(DriftError::RunTimeError(
313 format!(
314 "Feature mismatch, feature '{feature_name}' not found. Available features in the drift profile: {available_keys}"
315 ),
316 ));
317 }
318 Ok(())
319 })
320 }
321
322 pub fn compute_drift<F>(
323 &self,
324 features: &[String],
325 array: &ArrayView2<F>,
326 drift_profile: &PsiDriftProfile,
327 ) -> Result<PsiDriftMap, DriftError>
328 where
329 F: Float + Sync + FromPrimitive,
330 F: Into<f64>,
331 {
332 self.check_features(features, array, drift_profile)?;
333
334 let drift_values: Vec<_> = array
335 .axis_iter(Axis(1))
336 .zip(features)
337 .collect_vec()
338 .into_par_iter()
339 .map(|(column_vector, feature_name)| {
340 let feature_is_categorical = drift_profile
341 .config
342 .categorical_features
343 .as_ref()
344 .is_some_and(|features| features.contains(feature_name));
345 self.compute_feature_drift(
346 &column_vector,
347 drift_profile.features.get(feature_name).unwrap(),
348 feature_is_categorical,
349 )
350 })
351 .collect::<Result<Vec<f64>, DriftError>>()?;
352
353 let mut psi_drift_features = HashMap::new();
354
355 drift_values
356 .into_iter()
357 .zip(features)
358 .for_each(|(drift_value, feature_name)| {
359 psi_drift_features.insert(feature_name.clone(), drift_value);
360 });
361
362 Ok(PsiDriftMap {
363 features: psi_drift_features,
364 name: drift_profile.config.name.clone(),
365 space: drift_profile.config.space.clone(),
366 version: drift_profile.config.version.clone(),
367 })
368 }
369}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    // Feature names shared by the multi-column tests.
    fn three_features() -> Vec<String> {
        vec![
            "feature_1".to_string(),
            "feature_2".to_string(),
            "feature_3".to_string(),
        ]
    }

    #[test]
    fn test_check_features_all_exist() {
        let psi_monitor = PsiMonitor::default();

        let array = Array::random((1030, 3), Uniform::new(0., 10.));
        let features = three_features();

        let profile = psi_monitor
            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
            .unwrap();
        assert_eq!(profile.features.len(), 3);

        let result = psi_monitor.check_features(&features, &array.view(), &profile);

        assert!(result.is_ok());
    }

    #[test]
    fn test_check_features_missing_feature() {
        let psi_monitor = PsiMonitor::default();

        let array = Array::random((1030, 3), Uniform::new(0., 10.));
        let features = three_features();

        let profile = psi_monitor
            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
            .unwrap();

        // "feature_4" was never profiled, so validation must fail.
        let other_features = vec![
            "feature_1".to_string(),
            "feature_2".to_string(),
            "feature_4".to_string(),
        ];
        let result = psi_monitor.check_features(&other_features, &array.view(), &profile);

        assert!(result.is_err());
    }

    #[test]
    fn test_compute_psi_basic() {
        let proportions = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];

        let result = PsiMonitor::compute_psi(&proportions);

        let expected_psi = (0.3 - 0.2) * (0.3 / 0.2).ln()
            + (0.4 - 0.4) * (0.4 / 0.4).ln()
            + (0.3 - 0.4) * (0.3 / 0.4).ln();

        assert!((result - expected_psi).abs() < 1e-6);
    }

    #[test]
    fn test_compute_bin_count() {
        let psi_monitor = PsiMonitor::default();

        let data = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);

        let lower_threshold = 2.0;
        let upper_threshold = 6.0;

        let result =
            psi_monitor.compute_bin_count(&data.view(), &lower_threshold, &upper_threshold);

        assert_eq!(result, 3);
    }

    #[test]
    fn test_compute_psi_proportion_pairs_categorical() {
        let psi_monitor = PsiMonitor::default();

        let cat_vector = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);

        let cat_zero_bin = Bin {
            id: 0,
            lower_limit: None,
            upper_limit: None,
            proportion: 0.4,
        };

        let (_, prod_proportion) = psi_monitor
            .compute_psi_proportion_pairs(&cat_vector.view(), &cat_zero_bin, true)
            .unwrap();

        let expected_prod_proportion = 0.5;

        assert!(
            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
            "prod_proportion was expected to be 50%"
        );
    }

    #[test]
    fn test_compute_psi_proportion_pairs_non_categorical() {
        let psi_monitor = PsiMonitor::default();

        let vector = Array::from_vec(vec![
            12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
        ]);

        let bin = Bin {
            id: 1,
            lower_limit: Some(0.0),
            upper_limit: Some(11.0),
            proportion: 0.4,
        };

        let (_, prod_proportion) = psi_monitor
            .compute_psi_proportion_pairs(&vector.view(), &bin, false)
            .unwrap();

        let expected_prod_proportion = 0.4;

        assert!(
            (prod_proportion - expected_prod_proportion).abs() < 1e-9,
            "prod_proportion was expected to be 40%"
        );
    }

    #[test]
    fn test_compute_deciles_with_unsorted_input() {
        let psi_monitor = PsiMonitor::default();

        let unsorted_vector = Array::from_vec(vec![
            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
        ]);

        let column_view = unsorted_vector.view();

        let result = psi_monitor.compute_deciles(&column_view);

        let expected_deciles: [f64; 9] = [2.0, 4.0, 6.0, 10.0, 21.0, 39.0, 59.0, 71.0, 120.0];

        assert_eq!(
            result.unwrap().as_ref(),
            expected_deciles.as_ref(),
            "Deciles computed incorrectly for unsorted input"
        );
    }

    #[test]
    fn test_compute_deciles_requires_ten_values() {
        let psi_monitor = PsiMonitor::default();

        // Nine values: one short of the minimum needed for deciles.
        let short_vector = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

        let result = psi_monitor.compute_deciles(&short_vector.view());

        assert!(matches!(
            result,
            Err(DriftError::NotEnoughDecileValuesError)
        ));
    }

    #[test]
    fn test_create_bins_non_categorical() {
        let psi_monitor = PsiMonitor::default();

        let non_categorical_data = Array::from_vec(vec![
            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
        ]);

        let result = psi_monitor.create_numeric_bins(&ArrayView::from(&non_categorical_data));

        assert!(result.is_ok());
        let bins = result.unwrap();
        assert_eq!(bins.len(), 10);
    }

    #[test]
    fn test_create_bins_categorical() {
        let psi_monitor = PsiMonitor::default();

        let categorical_data = Array::from_vec(vec![
            1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
            1.0,
        ]);

        let bins = psi_monitor.create_categorical_bins(&ArrayView::from(&categorical_data));
        assert_eq!(bins.len(), 3);
    }

    #[test]
    fn test_create_2d_drift_profile() {
        let array = Array::random((1030, 3), Uniform::new(0., 10.));
        let array = array.mapv(|x| x as f32);

        let features = three_features();

        let monitor = PsiMonitor::default();
        let profile = monitor
            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
            .unwrap();

        assert_eq!(profile.features.len(), 3);
    }

    #[test]
    fn test_compute_drift() {
        let array = Array::random((1030, 3), Uniform::new(0., 10.));
        let array = array.mapv(|x| x as f32);

        let features = three_features();

        let monitor = PsiMonitor::default();

        let profile = monitor
            .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
            .unwrap();

        // Same data as the profile: every feature must show zero drift.
        let drift_map = monitor
            .compute_drift(&features, &array.view(), &profile)
            .unwrap();

        assert_eq!(drift_map.features.len(), 3);

        drift_map
            .features
            .values()
            .for_each(|value| assert!(*value == 0.0));

        // Freshly sampled (and slightly shifted) data: every feature must show drift.
        let mut new_array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);
        new_array.slice_mut(s![.., 0]).mapv_inplace(|x| x + 0.01);

        let new_drift_map = monitor
            .compute_drift(&features, &new_array.view(), &profile)
            .unwrap();

        new_drift_map
            .features
            .values()
            .for_each(|value| assert!(*value > 0.0));
    }
}