Skip to main content

scouter_drift/psi/
types.rs

1use crate::error::DriftError;
2use scouter_types::psi::{FeatureDistributions, PsiFeatureDriftProfile};
3use serde::{Deserialize, Serialize};
4use std::collections::{BTreeMap, HashMap};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct FeatureBinProportionPairs {
8    pub bins: Vec<String>,
9    pub pairs: Vec<(f64, f64)>,
10}
11
12impl FeatureBinProportionPairs {
13    pub fn from_observed_bin_proportions(
14        observed_bin_proportions: &BTreeMap<i32, f64>,
15        profile: &PsiFeatureDriftProfile,
16    ) -> Result<Self, DriftError> {
17        let (bins, pairs): (Vec<String>, Vec<(f64, f64)>) = profile
18            .bins
19            .iter()
20            .map(|bin| {
21                let observed_proportion = *observed_bin_proportions.get(&bin.id).unwrap_or(&0.0); // It's possible that there is no data for a bin, which would mean 0
22                (bin.id.to_string(), (bin.proportion, observed_proportion))
23            })
24            .unzip();
25
26        Ok(Self { bins, pairs })
27    }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct FeatureBinMapping {
32    pub features: HashMap<String, FeatureBinProportionPairs>,
33}
34
35impl FeatureBinMapping {
36    pub fn from_observed_bin_proportions(
37        observed_bin_proportions: &FeatureDistributions,
38        profiles_to_monitor: &[PsiFeatureDriftProfile],
39    ) -> Result<Self, DriftError> {
40        let features: HashMap<String, FeatureBinProportionPairs> = profiles_to_monitor
41            .iter()
42            .map(|profile| {
43                let proportion_pairs = FeatureBinProportionPairs::from_observed_bin_proportions(
44                    &observed_bin_proportions
45                        .distributions
46                        .get(&profile.id)
47                        .unwrap()
48                        .bins,
49                    profile,
50                )
51                .unwrap();
52                (profile.id.clone(), proportion_pairs)
53            })
54            .collect();
55
56        Ok(Self { features })
57    }
58}