1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9 Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Stateless entry point for building PSI drift profiles and scoring new data against them.
///
/// The type carries no fields; every method works purely on its arguments, so a value can be
/// created with [`Default::default`] and cloned freely.
#[derive(Debug, Clone, Default)]
pub struct PsiMonitor {}
15
16impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19 pub fn new() -> Self {
20 PsiMonitor {}
21 }
22
23 fn compute_bin_count<F>(
24 &self,
25 array: &ArrayView<F, Ix1>,
26 lower_threshold: &f64,
27 upper_threshold: &f64,
28 ) -> usize
29 where
30 F: Float + FromPrimitive,
31 F: Into<f64>,
32 {
33 array
34 .iter()
35 .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36 .count()
37 }
38
39 fn compute_deciles<F>(&self, column_vector: &ArrayView1<F>) -> Result<[F; 9], DriftError>
40 where
41 F: Float + Default,
42 F: Into<f64>,
43 {
44 if column_vector.len() < 10 {
46 return Err(DriftError::NotEnoughDecileValuesError);
47 }
48
49 let sorted_column_vector = column_vector
50 .iter()
51 .sorted_by(|a, b| a.partial_cmp(b).unwrap()) .cloned()
53 .collect_vec();
54
55 let n = sorted_column_vector.len();
56 let mut deciles: [F; 9] = Default::default();
57
58 for i in 1..=9 {
59 let index = ((i as f32 * (n as f32 - 1.0)) / 10.0).floor() as usize;
60 deciles[i - 1] = sorted_column_vector[index];
61 }
62 let decile_vec: [F; 9] = deciles
63 .to_vec()
64 .try_into()
65 .map_err(|_| DriftError::ConvertDecileToArray)?;
66
67 Ok(decile_vec)
68 }
69
70 fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
71 where
72 F: Float + FromPrimitive + Default + Sync,
73 F: Into<f64>,
74 {
75 let vector_len = column_vector.len() as f64;
76 let mut counts: HashMap<usize, usize> = HashMap::new();
77
78 for &value in column_vector.iter() {
79 let key = Into::<f64>::into(value) as usize;
80 *counts.entry(key).or_insert(0) += 1;
81 }
82
83 counts
84 .into_par_iter()
85 .map(|(id, count)| Bin {
86 id,
87 lower_limit: None,
88 upper_limit: None,
89 proportion: (count as f64) / vector_len,
90 })
91 .collect()
92 }
93
94 fn create_numeric_bins<F>(&self, column_vector: &ArrayView1<F>) -> Result<Vec<Bin>, DriftError>
95 where
96 F: Float + FromPrimitive + Default + Sync,
97 F: Into<f64>,
98 {
99 let deciles = self.compute_deciles(column_vector)?;
100
101 let bins: Vec<Bin> = (0..=deciles.len())
102 .into_par_iter()
103 .map(|decile| {
104 let lower = if decile == 0 {
105 F::neg_infinity()
106 } else {
107 deciles[decile - 1]
108 };
109 let upper = if decile == deciles.len() {
110 F::infinity()
111 } else {
112 deciles[decile]
113 };
114 let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
115 Bin {
116 id: decile + 1,
117 lower_limit: Some(lower.into()),
118 upper_limit: Some(upper.into()),
119 proportion: (bin_count as f64) / (column_vector.len() as f64),
120 }
121 })
122 .collect();
123 Ok(bins)
124 }
125
126 fn create_bins<F>(
127 &self,
128 feature_name: &String,
129 column_vector: &ArrayView<F, Ix1>,
130 drift_config: &PsiDriftConfig,
131 ) -> Result<(Vec<Bin>, BinType), DriftError>
132 where
133 F: Float + FromPrimitive + Default + Sync,
134 F: Into<f64>,
135 {
136 match &drift_config.categorical_features {
137 Some(features) if features.contains(feature_name) => {
138 Ok((
140 self.create_categorical_bins(column_vector),
141 BinType::Category,
142 ))
143 }
144 _ => {
145 Ok((self.create_numeric_bins(column_vector)?, BinType::Numeric))
147 }
148 }
149 }
150
151 fn create_psi_feature_drift_profile<F>(
152 &self,
153 feature_name: String,
154 column_vector: &ArrayView<F, Ix1>,
155 drift_config: &PsiDriftConfig,
156 ) -> Result<PsiFeatureDriftProfile, DriftError>
157 where
158 F: Float + Sync + FromPrimitive + Default,
159 F: Into<f64>,
160 {
161 let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
162
163 Ok(PsiFeatureDriftProfile {
164 id: feature_name,
165 bins,
166 timestamp: chrono::Utc::now(),
167 bin_type,
168 })
169 }
170
171 pub fn create_2d_drift_profile<F>(
172 &self,
173 features: &[String],
174 array: &ArrayView2<F>,
175 drift_config: &PsiDriftConfig,
176 ) -> Result<PsiDriftProfile, DriftError>
177 where
178 F: Float + Sync + FromPrimitive + Default,
179 F: Into<f64>,
180 {
181 let mut psi_feature_drift_profiles = HashMap::new();
182
183 assert_eq!(
185 features.len(),
186 array.shape()[1],
187 "Feature count must match column count."
188 );
189
190 let profile_vector = array
191 .axis_iter(Axis(1))
192 .zip(features)
193 .collect_vec()
194 .into_par_iter()
195 .map(|(column_vector, feature_name)| {
196 self.create_psi_feature_drift_profile(
197 feature_name.to_string(),
198 &column_vector,
199 drift_config,
200 )
201 })
202 .collect::<Result<Vec<_>, _>>()?;
203
204 profile_vector
205 .into_iter()
206 .zip(features)
207 .for_each(|(profile, feature_name)| {
208 psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209 });
210
211 Ok(PsiDriftProfile::new(
212 psi_feature_drift_profiles,
213 drift_config.clone(),
214 None,
215 ))
216 }
217
218 fn compute_psi_proportion_pairs<F>(
219 &self,
220 column_vector: &ArrayView<F, Ix1>,
221 bin: &Bin,
222 feature_is_categorical: bool,
223 ) -> Result<(f64, f64), DriftError>
224 where
225 F: Float + FromPrimitive,
226 F: Into<f64>,
227 {
228 if feature_is_categorical {
229 let bin_count = column_vector
230 .iter()
231 .filter(|&&value| value.into() == bin.id as f64)
232 .count();
233 return Ok((
234 bin.proportion,
235 (bin_count as f64) / (column_vector.len() as f64),
236 ));
237 }
238
239 let bin_count = self.compute_bin_count(
240 column_vector,
241 &bin.lower_limit.unwrap(),
242 &bin.upper_limit.unwrap(),
243 );
244
245 Ok((
246 bin.proportion,
247 (bin_count as f64) / (column_vector.len() as f64),
248 ))
249 }
250
251 pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
252 let epsilon = 1e-10;
253 proportion_pairs
254 .iter()
255 .map(|(p, q)| {
256 let p_adj = p + epsilon;
257 let q_adj = q + epsilon;
258 (p_adj - q_adj) * (p_adj / q_adj).ln()
259 })
260 .sum()
261 }
262
263 fn compute_feature_drift<F>(
264 &self,
265 column_vector: &ArrayView<F, Ix1>,
266 feature_drift_profile: &PsiFeatureDriftProfile,
267 feature_is_categorical: bool,
268 ) -> Result<f64, DriftError>
269 where
270 F: Float + Sync + FromPrimitive,
271 F: Into<f64>,
272 {
273 let bins = &feature_drift_profile.bins;
274 let feature_proportions: Vec<(f64, f64)> = bins
275 .into_par_iter()
276 .map(|bin| {
277 self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
278 })
279 .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
280
281 Ok(PsiMonitor::compute_psi(&feature_proportions))
282 }
283
284 fn check_features<F>(
285 &self,
286 features: &[String],
287 array: &ArrayView2<F>,
288 drift_profile: &PsiDriftProfile,
289 ) -> Result<(), DriftError>
290 where
291 F: Float + Sync + FromPrimitive,
292 F: Into<f64>,
293 {
294 assert_eq!(
295 features.len(),
296 array.shape()[1],
297 "Feature count must match column count."
298 );
299
300 features
301 .iter()
302 .try_for_each(|feature_name| {
303 if !drift_profile.features.contains_key(feature_name) {
304 let available_keys = drift_profile
306 .features
307 .keys()
308 .cloned()
309 .collect::<Vec<_>>()
310 .join(", ");
311
312 return Err(DriftError::RunTimeError(
313 format!(
314 "Feature mismatch, feature '{}' not found. Available features in the drift profile: {}",
315 feature_name, available_keys
316 ),
317 ));
318 }
319 Ok(())
320 })
321 }
322
323 pub fn compute_drift<F>(
324 &self,
325 features: &[String],
326 array: &ArrayView2<F>,
327 drift_profile: &PsiDriftProfile,
328 ) -> Result<PsiDriftMap, DriftError>
329 where
330 F: Float + Sync + FromPrimitive,
331 F: Into<f64>,
332 {
333 self.check_features(features, array, drift_profile)?;
334
335 let drift_values: Vec<_> = array
336 .axis_iter(Axis(1))
337 .zip(features)
338 .collect_vec()
339 .into_par_iter()
340 .map(|(column_vector, feature_name)| {
341 let feature_is_categorical = drift_profile
342 .config
343 .categorical_features
344 .as_ref()
345 .is_some_and(|features| features.contains(feature_name));
346 self.compute_feature_drift(
347 &column_vector,
348 drift_profile.features.get(feature_name).unwrap(),
349 feature_is_categorical,
350 )
351 })
352 .collect::<Result<Vec<f64>, DriftError>>()?;
353
354 let mut psi_drift_features = HashMap::new();
355
356 drift_values
357 .into_iter()
358 .zip(features)
359 .for_each(|(drift_value, feature_name)| {
360 psi_drift_features.insert(feature_name.clone(), drift_value);
361 });
362
363 Ok(PsiDriftMap {
364 features: psi_drift_features,
365 name: drift_profile.config.name.clone(),
366 space: drift_profile.config.space.clone(),
367 version: drift_profile.config.version.clone(),
368 })
369 }
370}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    /// Row count shared by the randomly generated fixtures.
    const ROWS: usize = 1030;

    /// The three feature names used by every 2-D fixture.
    fn feature_names() -> Vec<String> {
        (1..=3).map(|index| format!("feature_{}", index)).collect()
    }

    /// A `ROWS x 3` matrix of `f64` values drawn uniformly from `[0, 10)`.
    fn random_f64_matrix() -> Array2<f64> {
        Array::random((ROWS, 3), Uniform::new(0., 10.))
    }

    /// Same distribution as [`random_f64_matrix`], narrowed to `f32`.
    fn random_f32_matrix() -> Array2<f32> {
        random_f64_matrix().mapv(|x| x as f32)
    }

    /// The 20 unsorted values shared by the decile and numeric-bin tests.
    fn unsorted_values() -> Array1<f64> {
        Array::from_vec(vec![
            120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
            53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
        ])
    }

    #[test]
    fn test_check_features_all_exist() {
        let monitor = PsiMonitor::default();
        let data = random_f64_matrix();
        let names = feature_names();

        let profile = monitor
            .create_2d_drift_profile(&names, &data.view(), &PsiDriftConfig::default())
            .unwrap();
        assert_eq!(profile.features.len(), 3);

        // Every name used to build the profile must be accepted by the check.
        let outcome = monitor.check_features(&names, &data.view(), &profile);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_compute_psi_basic() {
        let pairs = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];

        let actual = PsiMonitor::compute_psi(&pairs);

        // Reference value computed without the epsilon smoothing.
        let expected: f64 = pairs.iter().map(|(p, q)| (p - q) * (p / q).ln()).sum();

        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn test_compute_bin_count() {
        let monitor = PsiMonitor::default();
        let values = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);
        let (lower, upper) = (2.0, 6.0);

        // 2.5, 3.7 and 5.0 are the only values inside (2.0, 6.0].
        let count = monitor.compute_bin_count(&values.view(), &lower, &upper);

        assert_eq!(count, 3);
    }

    #[test]
    fn test_compute_psi_proportion_pairs_categorical() {
        let monitor = PsiMonitor::default();
        let categories = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
        let zero_bin = Bin {
            id: 0,
            lower_limit: None,
            upper_limit: None,
            proportion: 0.4,
        };

        let (_, observed) = monitor
            .compute_psi_proportion_pairs(&categories.view(), &zero_bin, true)
            .unwrap();

        // Half of the values belong to category 0.
        let expected = 0.5;
        assert!(
            (observed - expected).abs() < 1e-9,
            "prod_proportion was expected to be 50%"
        );
    }

    #[test]
    fn test_compute_psi_proportion_pairs_non_categorical() {
        let monitor = PsiMonitor::default();
        let values = Array::from_vec(vec![
            12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
        ]);
        let numeric_bin = Bin {
            id: 1,
            lower_limit: Some(0.0),
            upper_limit: Some(11.0),
            proportion: 0.4,
        };

        let (_, observed) = monitor
            .compute_psi_proportion_pairs(&values.view(), &numeric_bin, false)
            .unwrap();

        // 11.0, 10.0, 1.0 and 10.0 fall inside (0.0, 11.0].
        let expected = 0.4;
        assert!(
            (observed - expected).abs() < 1e-9,
            "prod_proportion was expected to be 40%"
        );
    }

    #[test]
    fn test_compute_deciles_with_unsorted_input() {
        let monitor = PsiMonitor::default();
        let values = unsorted_values();

        let deciles = monitor.compute_deciles(&values.view()).unwrap();

        let expected_deciles: [f64; 9] = [2.0, 4.0, 6.0, 10.0, 21.0, 39.0, 59.0, 71.0, 120.0];
        assert_eq!(
            deciles, expected_deciles,
            "Deciles computed incorrectly for unsorted input"
        );
    }

    #[test]
    fn test_create_bins_non_categorical() {
        let monitor = PsiMonitor::default();
        let values = unsorted_values();

        let outcome = monitor.create_numeric_bins(&values.view());

        assert!(outcome.is_ok());
        // Nine decile cut points always delimit ten bins.
        assert_eq!(outcome.unwrap().len(), 10);
    }

    #[test]
    fn test_create_bins_categorical() {
        let monitor = PsiMonitor::default();
        let categories = Array::from_vec(vec![
            1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
            1.0,
        ]);

        // Three distinct categories (1, 2, 3) give three bins.
        let bins = monitor.create_categorical_bins(&categories.view());

        assert_eq!(bins.len(), 3);
    }

    #[test]
    fn test_create_2d_drift_profile() {
        let data = random_f32_matrix();
        let names = feature_names();

        let profile = PsiMonitor::default()
            .create_2d_drift_profile(&names, &data.view(), &PsiDriftConfig::default())
            .unwrap();

        assert_eq!(profile.features.len(), 3);
    }

    #[test]
    fn test_compute_drift() {
        let monitor = PsiMonitor::default();
        let baseline = random_f32_matrix();
        let names = feature_names();

        let profile = monitor
            .create_2d_drift_profile(&names, &baseline.view(), &PsiDriftConfig::default())
            .unwrap();

        // Scoring the data the profile was built from must report no drift at all.
        let drift_map = monitor
            .compute_drift(&names, &baseline.view(), &profile)
            .unwrap();
        assert_eq!(drift_map.features.len(), 3);
        for value in drift_map.features.values() {
            assert!(*value == 0.0);
        }

        // A freshly drawn sample (first column shifted slightly) must show some drift.
        let mut shifted = random_f32_matrix();
        shifted.column_mut(0).mapv_inplace(|x| x + 0.01);

        let new_drift_map = monitor
            .compute_drift(&names, &shifted.view(), &profile)
            .unwrap();
        for value in new_drift_map.features.values() {
            assert!(*value > 0.0);
        }
    }
}