1use crate::error::DriftError;
2use crate::utils::CategoricalFeatureHelpers;
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Axis;
6use num_traits::{Float, FromPrimitive};
7use rayon::prelude::*;
8use scouter_types::psi::{
9 Bin, BinType, PsiDriftConfig, PsiDriftMap, PsiDriftProfile, PsiFeatureDriftProfile,
10};
11use std::collections::HashMap;
12
/// Stateless entry point for Population Stability Index (PSI) monitoring.
///
/// The type carries no fields; its methods build PSI drift profiles from
/// 2-D arrays and compute PSI drift values against an existing profile.
/// Because it holds no data it is trivially copyable and printable.
#[derive(Debug, Clone, Copy, Default)]
pub struct PsiMonitor {}
15
16impl CategoricalFeatureHelpers for PsiMonitor {}
17
18impl PsiMonitor {
19 pub fn new() -> Self {
20 PsiMonitor {}
21 }
22
23 fn compute_bin_count<F>(
24 &self,
25 array: &ArrayView<F, Ix1>,
26 lower_threshold: &f64,
27 upper_threshold: &f64,
28 ) -> usize
29 where
30 F: Float + FromPrimitive,
31 F: Into<f64>,
32 {
33 array
34 .iter()
35 .filter(|&&value| value.into() > *lower_threshold && value.into() <= *upper_threshold)
36 .count()
37 }
38
39 fn create_categorical_bins<F>(&self, column_vector: &ArrayView<F, Ix1>) -> Vec<Bin>
40 where
41 F: Float + FromPrimitive + Default + Sync,
42 F: Into<f64>,
43 {
44 let vector_len = column_vector.len() as f64;
45 let mut counts: HashMap<usize, usize> = HashMap::new();
46
47 for &value in column_vector.iter() {
48 let key = Into::<f64>::into(value) as usize;
49 *counts.entry(key).or_insert(0) += 1;
50 }
51
52 counts
53 .into_par_iter()
54 .map(|(id, count)| Bin {
55 id,
56 lower_limit: None,
57 upper_limit: None,
58 proportion: (count as f64) / vector_len,
59 })
60 .collect()
61 }
62
63 fn create_numeric_bins<F>(
64 &self,
65 column_vector: &ArrayView1<F>,
66 drift_config: &PsiDriftConfig,
67 ) -> Result<Vec<Bin>, DriftError>
68 where
69 F: Float + FromPrimitive + Default + Sync,
70 F: Into<f64>,
71 {
72 let edges = drift_config
73 .binning_strategy
74 .compute_edges(column_vector)
75 .map_err(|e| DriftError::BinningError(e.to_string()))?;
76
77 let bins: Vec<Bin> = (0..=edges.len())
78 .into_par_iter()
79 .map(|decile| {
80 let lower = if decile == 0 {
81 F::neg_infinity()
82 } else {
83 edges[decile - 1]
84 };
85 let upper = if decile == edges.len() {
86 F::infinity()
87 } else {
88 edges[decile]
89 };
90 let bin_count = self.compute_bin_count(column_vector, &lower.into(), &upper.into());
91 Bin {
92 id: decile + 1,
93 lower_limit: Some(lower.into()),
94 upper_limit: Some(upper.into()),
95 proportion: (bin_count as f64) / (column_vector.len() as f64),
96 }
97 })
98 .collect();
99 Ok(bins)
100 }
101
102 fn create_bins<F>(
103 &self,
104 feature_name: &String,
105 column_vector: &ArrayView<F, Ix1>,
106 drift_config: &PsiDriftConfig,
107 ) -> Result<(Vec<Bin>, BinType), DriftError>
108 where
109 F: Float + FromPrimitive + Default + Sync,
110 F: Into<f64>,
111 {
112 match &drift_config.categorical_features {
113 Some(features) if features.contains(feature_name) => {
114 Ok((
116 self.create_categorical_bins(column_vector),
117 BinType::Category,
118 ))
119 }
120 _ => {
121 Ok((
123 self.create_numeric_bins(column_vector, drift_config)?,
124 BinType::Numeric,
125 ))
126 }
127 }
128 }
129
130 fn create_psi_feature_drift_profile<F>(
131 &self,
132 feature_name: String,
133 column_vector: &ArrayView<F, Ix1>,
134 drift_config: &PsiDriftConfig,
135 ) -> Result<PsiFeatureDriftProfile, DriftError>
136 where
137 F: Float + Sync + FromPrimitive + Default,
138 F: Into<f64>,
139 {
140 let (bins, bin_type) = self.create_bins(&feature_name, column_vector, drift_config)?;
141
142 Ok(PsiFeatureDriftProfile {
143 id: feature_name,
144 bins,
145 timestamp: chrono::Utc::now(),
146 bin_type,
147 })
148 }
149
150 fn clean_column_vector<F>(column_vector: &ArrayView<F, Ix1>) -> Array<F, Ix1>
151 where
152 F: Float,
153 {
154 Array1::from(
155 column_vector
156 .iter()
157 .filter(|&&x| x.is_finite())
158 .cloned()
159 .collect::<Vec<F>>(),
160 )
161 }
162
163 pub fn create_2d_drift_profile<F>(
164 &self,
165 features: &[String],
166 array: &ArrayView2<F>,
167 drift_config: &PsiDriftConfig,
168 ) -> Result<PsiDriftProfile, DriftError>
169 where
170 F: Float + Sync + FromPrimitive + Default,
171 F: Into<f64>,
172 {
173 let mut psi_feature_drift_profiles = HashMap::new();
174
175 assert_eq!(
177 features.len(),
178 array.shape()[1],
179 "Feature count must match column count."
180 );
181
182 let profile_vector = array
183 .axis_iter(Axis(1))
184 .zip(features)
185 .collect_vec()
186 .into_par_iter()
187 .map(|(column_vector, feature_name)| {
188 let clean_column_vector = Self::clean_column_vector(&column_vector);
189
190 if clean_column_vector.is_empty() {
191 return Err(DriftError::EmptyArrayError(
192 format!("PSI drift profile creation failure, unable to create psi feature drift profile for {feature_name}")
193 ));
194 }
195
196 self.create_psi_feature_drift_profile(
197 feature_name.to_string(),
198 &clean_column_vector.view(),
199 drift_config,
200 )
201 })
202 .collect::<Result<Vec<_>, _>>()?;
203
204 profile_vector
205 .into_iter()
206 .zip(features)
207 .for_each(|(profile, feature_name)| {
208 psi_feature_drift_profiles.insert(feature_name.clone(), profile);
209 });
210
211 Ok(PsiDriftProfile::new(
212 psi_feature_drift_profiles,
213 drift_config.clone(),
214 ))
215 }
216
217 fn compute_psi_proportion_pairs<F>(
218 &self,
219 column_vector: &ArrayView<F, Ix1>,
220 bin: &Bin,
221 feature_is_categorical: bool,
222 ) -> Result<(f64, f64), DriftError>
223 where
224 F: Float + FromPrimitive,
225 F: Into<f64>,
226 {
227 if feature_is_categorical {
228 let bin_count = column_vector
229 .iter()
230 .filter(|&&value| value.into() == bin.id as f64)
231 .count();
232 return Ok((
233 bin.proportion,
234 (bin_count as f64) / (column_vector.len() as f64),
235 ));
236 }
237
238 let bin_count = self.compute_bin_count(
239 column_vector,
240 &bin.lower_limit.unwrap(),
241 &bin.upper_limit.unwrap(),
242 );
243
244 Ok((
245 bin.proportion,
246 (bin_count as f64) / (column_vector.len() as f64),
247 ))
248 }
249
/// Computes the Population Stability Index over a set of proportion pairs.
///
/// For every pair `(p, q)` the term `(p - q) * ln(p / q)` is evaluated and
/// the terms are summed. A small constant is added to both members of each
/// pair first, so a proportion of `0.0` can neither divide by zero nor feed
/// zero into `ln`.
///
/// # Arguments
/// * `proportion_pairs` - the `(p, q)` proportions, one pair per bin.
///
/// # Returns
/// The summed PSI; zero for an empty slice.
pub fn compute_psi(proportion_pairs: &[(f64, f64)]) -> f64 {
    const EPSILON: f64 = 1e-10;

    proportion_pairs
        .iter()
        .map(|&(first, second)| {
            let first = first + EPSILON;
            let second = second + EPSILON;
            (first - second) * (first / second).ln()
        })
        .sum()
}
261
262 fn compute_feature_drift<F>(
263 &self,
264 column_vector: &ArrayView<F, Ix1>,
265 feature_drift_profile: &PsiFeatureDriftProfile,
266 feature_is_categorical: bool,
267 ) -> Result<f64, DriftError>
268 where
269 F: Float + Sync + FromPrimitive,
270 F: Into<f64>,
271 {
272 let bins = &feature_drift_profile.bins;
273 let feature_proportions: Vec<(f64, f64)> = bins
274 .into_par_iter()
275 .map(|bin| {
276 self.compute_psi_proportion_pairs(column_vector, bin, feature_is_categorical)
277 })
278 .collect::<Result<Vec<(f64, f64)>, DriftError>>()?;
279
280 Ok(PsiMonitor::compute_psi(&feature_proportions))
281 }
282
283 fn check_features<F>(
284 &self,
285 features: &[String],
286 array: &ArrayView2<F>,
287 drift_profile: &PsiDriftProfile,
288 ) -> Result<(), DriftError>
289 where
290 F: Float + Sync + FromPrimitive,
291 F: Into<f64>,
292 {
293 assert_eq!(
294 features.len(),
295 array.shape()[1],
296 "Feature count must match column count."
297 );
298
299 features
300 .iter()
301 .try_for_each(|feature_name| {
302 if !drift_profile.features.contains_key(feature_name) {
303 let available_keys = drift_profile
305 .features
306 .keys()
307 .cloned()
308 .collect::<Vec<_>>()
309 .join(", ");
310
311 return Err(DriftError::RunTimeError(
312 format!(
313 "Feature mismatch, feature '{feature_name}' not found. Available features in the drift profile: {available_keys}"
314 ),
315 ));
316 }
317 Ok(())
318 })
319 }
320
321 pub fn compute_drift<F>(
322 &self,
323 features: &[String],
324 array: &ArrayView2<F>,
325 drift_profile: &PsiDriftProfile,
326 ) -> Result<PsiDriftMap, DriftError>
327 where
328 F: Float + Sync + FromPrimitive,
329 F: Into<f64>,
330 {
331 self.check_features(features, array, drift_profile)?;
332
333 let drift_values: Vec<_> = array
334 .axis_iter(Axis(1))
335 .zip(features)
336 .collect_vec()
337 .into_par_iter()
338 .map(|(column_vector, feature_name)| {
339 let feature_is_categorical = drift_profile
340 .config
341 .categorical_features
342 .as_ref()
343 .is_some_and(|features| features.contains(feature_name));
344 self.compute_feature_drift(
345 &column_vector,
346 drift_profile.features.get(feature_name).unwrap(),
347 feature_is_categorical,
348 )
349 })
350 .collect::<Result<Vec<f64>, DriftError>>()?;
351
352 let mut psi_drift_features = HashMap::new();
353
354 drift_values
355 .into_iter()
356 .zip(features)
357 .for_each(|(drift_value, feature_name)| {
358 psi_drift_features.insert(feature_name.clone(), drift_value);
359 });
360
361 Ok(PsiDriftMap {
362 features: psi_drift_features,
363 name: drift_profile.config.name.clone(),
364 space: drift_profile.config.space.clone(),
365 version: drift_profile.config.version.clone(),
366 })
367 }
368}
369#[cfg(test)]
370mod tests {
371 use super::*;
372 use ndarray::Array;
373 use ndarray_rand::rand_distr::Uniform;
374 use ndarray_rand::RandomExt;
375
/// `check_features` accepts the very features a profile was built from.
#[test]
fn test_check_features_all_exist() {
    let features: Vec<String> = ["feature_1", "feature_2", "feature_3"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let array = Array::random((1030, 3), Uniform::new(0., 10.));
    let psi_monitor = PsiMonitor::default();

    let profile = psi_monitor
        .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
        .unwrap();
    assert_eq!(profile.features.len(), 3);

    let outcome = psi_monitor.check_features(&features, &array.view(), &profile);
    assert!(outcome.is_ok());
}
398
/// `compute_psi` agrees with the hand-written PSI formula to within 1e-6.
#[test]
fn test_compute_psi_basic() {
    let proportions: Vec<(f64, f64)> = vec![(0.3, 0.2), (0.4, 0.4), (0.3, 0.4)];

    // Reference value computed without the epsilon adjustment.
    let first_term = (0.3_f64 - 0.2) * (0.3_f64 / 0.2).ln();
    let second_term = (0.4_f64 - 0.4) * (0.4_f64 / 0.4).ln();
    let third_term = (0.3_f64 - 0.4) * (0.3_f64 / 0.4).ln();
    let expected_psi = first_term + second_term + third_term;

    let result = PsiMonitor::compute_psi(&proportions);

    assert!((result - expected_psi).abs() < 1e-6);
}
412
/// Only 2.5, 3.7 and 5.0 lie inside `(2.0, 6.0]`.
#[test]
fn test_compute_bin_count() {
    let data = Array1::from_vec(vec![1.0, 2.5, 3.7, 5.0, 6.3, 8.1]);
    let (lower_threshold, upper_threshold) = (2.0, 6.0);

    let count = PsiMonitor::default().compute_bin_count(
        &data.view(),
        &lower_threshold,
        &upper_threshold,
    );

    assert_eq!(count, 3);
}
429
/// Half of the categorical values equal the bin id `0`.
#[test]
fn test_compute_psi_proportion_pairs_categorical() {
    let cat_vector = Array::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]);
    let cat_zero_bin = Bin {
        id: 0,
        lower_limit: None,
        upper_limit: None,
        proportion: 0.4,
    };
    let expected_prod_proportion = 0.5;

    let psi_monitor = PsiMonitor::default();
    let pair = psi_monitor
        .compute_psi_proportion_pairs(&cat_vector.view(), &cat_zero_bin, true)
        .unwrap();
    let prod_proportion = pair.1;

    assert!(
        (prod_proportion - expected_prod_proportion).abs() < 1e-9,
        "prod_proportion was expected to be 50%"
    );
}
454
/// Four of the ten values (11, 10, 1, 10) lie inside `(0.0, 11.0]`.
#[test]
fn test_compute_psi_proportion_pairs_non_categorical() {
    let vector = Array::from_vec(vec![
        12.0, 11.0, 10.0, 1.0, 10.0, 21.0, 19.0, 12.0, 12.0, 23.0,
    ]);
    let bin = Bin {
        id: 1,
        lower_limit: Some(0.0),
        upper_limit: Some(11.0),
        proportion: 0.4,
    };
    let expected_prod_proportion = 0.4;

    let psi_monitor = PsiMonitor::default();
    let pair = psi_monitor
        .compute_psi_proportion_pairs(&vector.view(), &bin, false)
        .unwrap();
    let prod_proportion = pair.1;

    assert!(
        (prod_proportion - expected_prod_proportion).abs() < 1e-9,
        "prod_proportion was expected to be 40%"
    );
}
481
/// The default config yields ten numeric bins for this sample.
#[test]
fn test_create_bins_non_categorical() {
    let non_categorical_data = Array::from_vec(vec![
        120.0, 1.0, 33.0, 71.0, 15.0, 59.0, 8.0, 62.0, 4.0, 21.0, 10.0, 2.0, 344.0, 437.0,
        53.0, 39.0, 83.0, 6.0, 4.30, 2.0,
    ]);
    let config = PsiDriftConfig::default();

    let result =
        PsiMonitor::default().create_numeric_bins(&non_categorical_data.view(), &config);

    assert!(result.is_ok());
    assert_eq!(result.unwrap().len(), 10);
}
500
/// Three distinct categories (1, 2, 3) produce three bins.
#[test]
fn test_create_bins_categorical() {
    let categorical_data = Array::from_vec(vec![
        1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 3.0, 1.0,
        1.0,
    ]);

    let bins = PsiMonitor::default().create_categorical_bins(&categorical_data.view());

    assert_eq!(bins.len(), 3);
}
513
/// A three-column `f32` array produces a profile with three features.
#[test]
fn test_create_2d_drift_profile() {
    let features: Vec<String> = ["feature_1", "feature_2", "feature_3"]
        .iter()
        .map(|name| name.to_string())
        .collect();

    // Generate as f64, then narrow to f32 to exercise the generic path.
    let array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);

    let profile = PsiMonitor::default()
        .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
        .unwrap();

    assert_eq!(profile.features.len(), 3);
}
535
/// Scoring the profiling data gives zero PSI; fresh random data gives
/// strictly positive PSI for every feature.
#[test]
fn test_compute_drift() {
    let features: Vec<String> = ["feature_1", "feature_2", "feature_3"]
        .iter()
        .map(|name| name.to_string())
        .collect();
    let monitor = PsiMonitor::default();

    let array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);

    let profile = monitor
        .create_2d_drift_profile(&features, &array.view(), &PsiDriftConfig::default())
        .unwrap();

    // Same data as the profile: no drift expected.
    let drift_map = monitor
        .compute_drift(&features, &array.view(), &profile)
        .unwrap();
    assert_eq!(drift_map.features.len(), 3);
    for value in drift_map.features.values() {
        assert!(*value == 0.0);
    }

    // Independent random sample, with the first column additionally shifted.
    let mut new_array = Array::random((1030, 3), Uniform::new(0., 10.)).mapv(|x| x as f32);
    new_array.slice_mut(s![.., 0]).mapv_inplace(|x| x + 0.01);

    let new_drift_map = monitor
        .compute_drift(&features, &new_array.view(), &profile)
        .unwrap();
    for value in new_drift_map.features.values() {
        assert!(*value > 0.0);
    }
}
582
/// A column without any rows must be rejected with `EmptyArrayError`.
#[test]
fn test_empty_array() {
    let features = vec!["feature_1".to_string()];
    let data = Array::<f64, _>::zeros((0, 1));

    let profile = PsiMonitor::default().create_2d_drift_profile(
        &features,
        &data.view(),
        &PsiDriftConfig::default(),
    );

    assert!(profile.is_err());
    let Err(DriftError::EmptyArrayError(msg)) = profile else {
        panic!("Expected EmptyArrayError");
    };
    assert_eq!(msg, "PSI drift profile creation failure, unable to create psi feature drift profile for feature_1");
}
602}