scouter_types/psi/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::binning::equal_width::EqualWidthBinning;
3use crate::binning::quantile::QuantileBinning;
4use crate::binning::strategy::BinningStrategy;
5use crate::error::{ProfileError, TypeError};
6use crate::psi::alert::PsiAlertConfig;
7use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
8use crate::ProfileRequest;
9use crate::VersionRequest;
10use crate::{
11    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
12    PyHelperFuncs, DEFAULT_VERSION, MISSING,
13};
14use chrono::Utc;
15use core::fmt::Debug;
16use pyo3::prelude::*;
17use pyo3::types::PyDict;
18use scouter_semver::VersionType;
19use serde::de::{self, MapAccess, Visitor};
20use serde::ser::SerializeStruct;
21use serde::{Deserialize, Deserializer, Serialize, Serializer};
22use serde_json::Value;
23use std::collections::{BTreeMap, HashMap};
24use std::path::PathBuf;
25use tracing::debug;
26
27#[pyclass(eq)]
28#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
29pub enum BinType {
30    Numeric,
31    Category,
32}
33
34#[pyclass]
35#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
36pub struct PsiDriftConfig {
37    #[pyo3(get, set)]
38    pub space: String,
39
40    #[pyo3(get, set)]
41    pub name: String,
42
43    #[pyo3(get, set)]
44    pub version: String,
45
46    #[pyo3(get, set)]
47    pub alert_config: PsiAlertConfig,
48
49    #[pyo3(get)]
50    #[serde(default)]
51    pub feature_map: FeatureMap,
52
53    #[pyo3(get)]
54    #[serde(default = "default_drift_type")]
55    pub drift_type: DriftType,
56
57    #[pyo3(get, set)]
58    pub categorical_features: Option<Vec<String>>,
59
60    pub binning_strategy: BinningStrategy,
61}
62
63fn default_drift_type() -> DriftType {
64    DriftType::Psi
65}
66
67impl PsiDriftConfig {
68    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
69        self.feature_map = feature_map;
70    }
71}
72
73#[pymethods]
74#[allow(clippy::too_many_arguments)]
75impl PsiDriftConfig {
76    #[new]
77    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None, binning_strategy=None))]
78    pub fn new(
79        space: &str,
80        name: &str,
81        version: &str,
82        alert_config: PsiAlertConfig,
83        config_path: Option<PathBuf>,
84        categorical_features: Option<Vec<String>>,
85        binning_strategy: Option<&Bound<'_, PyAny>>,
86    ) -> Result<Self, ProfileError> {
87        if let Some(config_path) = config_path {
88            let config = PsiDriftConfig::load_from_json_file(config_path);
89            return config;
90        }
91
92        let binning_strategy = match binning_strategy {
93            None => BinningStrategy::default(),
94            Some(strategy) => {
95                if strategy.is_instance_of::<QuantileBinning>() {
96                    BinningStrategy::QuantileBinning(strategy.extract()?)
97                } else if strategy.is_instance_of::<EqualWidthBinning>() {
98                    BinningStrategy::EqualWidthBinning(strategy.extract()?)
99                } else {
100                    return Err(ProfileError::InvalidBinningStrategyError);
101                }
102            }
103        };
104
105        println!("{binning_strategy:#?}");
106
107        if name == MISSING || space == MISSING {
108            debug!("Name and space were not provided. Defaulting to __missing__");
109        }
110
111        Ok(Self {
112            name: name.to_string(),
113            space: space.to_string(),
114            version: version.to_string(),
115            alert_config,
116            categorical_features,
117            feature_map: FeatureMap::default(),
118            drift_type: DriftType::Psi,
119            binning_strategy,
120        })
121    }
122
123    #[staticmethod]
124    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
125        // deserialize the string to a struct
126
127        let file = std::fs::read_to_string(&path)?;
128
129        Ok(serde_json::from_str(&file)?)
130    }
131
132    pub fn __str__(&self) -> String {
133        // serialize the struct to a string
134        PyHelperFuncs::__str__(self)
135    }
136
137    pub fn model_dump_json(&self) -> String {
138        // serialize the struct to a string
139        PyHelperFuncs::__json__(self)
140    }
141
142    #[allow(clippy::too_many_arguments)]
143    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None, categorical_features=None, binning_strategy=None))]
144    pub fn update_config_args(
145        &mut self,
146        space: Option<String>,
147        name: Option<String>,
148        version: Option<String>,
149        alert_config: Option<PsiAlertConfig>,
150        categorical_features: Option<Vec<String>>,
151        binning_strategy: Option<&Bound<'_, PyAny>>,
152    ) -> Result<(), TypeError> {
153        if name.is_some() {
154            self.name = name.ok_or(TypeError::MissingNameError)?;
155        }
156
157        if space.is_some() {
158            self.space = space.ok_or(TypeError::MissingSpaceError)?;
159        }
160
161        if version.is_some() {
162            self.version = version.ok_or(TypeError::MissingVersionError)?;
163        }
164
165        if alert_config.is_some() {
166            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
167        }
168
169        if categorical_features.is_some() {
170            self.categorical_features = categorical_features;
171        }
172
173        if let Some(binning_strategy) = binning_strategy {
174            if binning_strategy.is_instance_of::<QuantileBinning>() {
175                self.binning_strategy =
176                    BinningStrategy::QuantileBinning(binning_strategy.extract()?);
177            } else if binning_strategy.is_instance_of::<EqualWidthBinning>() {
178                self.binning_strategy =
179                    BinningStrategy::EqualWidthBinning(binning_strategy.extract()?);
180            } else {
181                return Err(TypeError::InvalidBinningStrategyError);
182            }
183        }
184
185        Ok(())
186    }
187
188    #[getter]
189    pub fn binning_strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
190        self.binning_strategy.strategy(py)
191    }
192
193    #[setter]
194    pub fn set_binning_strategy(&mut self, strategy: &Bound<'_, PyAny>) -> PyResult<()> {
195        if strategy.is_instance_of::<QuantileBinning>() {
196            self.binning_strategy = BinningStrategy::QuantileBinning(strategy.extract()?);
197        } else if strategy.is_instance_of::<EqualWidthBinning>() {
198            self.binning_strategy = BinningStrategy::EqualWidthBinning(strategy.extract()?);
199        } else {
200            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
201                "Invalid binning strategy type",
202            ));
203        }
204        Ok(())
205    }
206}
207
208impl Default for PsiDriftConfig {
209    fn default() -> Self {
210        PsiDriftConfig {
211            name: "__missing__".to_string(),
212            space: "__missing__".to_string(),
213            version: DEFAULT_VERSION.to_string(),
214            feature_map: FeatureMap::default(),
215            alert_config: PsiAlertConfig::default(),
216            drift_type: DriftType::Psi,
217            categorical_features: None,
218            binning_strategy: BinningStrategy::default(),
219        }
220    }
221}
222// TODO dry this out
223
224impl DispatchDriftConfig for PsiDriftConfig {
225    fn get_drift_args(&self) -> DriftArgs {
226        DriftArgs {
227            name: self.name.clone(),
228            space: self.space.clone(),
229            version: self.version.clone(),
230            dispatch_config: self.alert_config.dispatch_config.clone(),
231        }
232    }
233}
234
235#[pyclass]
236#[derive(Debug, Clone, PartialEq)]
237pub struct Bin {
238    #[pyo3(get)]
239    pub id: usize,
240
241    #[pyo3(get)]
242    pub lower_limit: Option<f64>,
243
244    #[pyo3(get)]
245    pub upper_limit: Option<f64>,
246
247    #[pyo3(get)]
248    pub proportion: f64,
249}
250
251impl Serialize for Bin {
252    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
253    where
254        S: Serializer,
255    {
256        let mut state = serializer.serialize_struct("Bin", 4)?;
257        state.serialize_field("id", &self.id)?;
258
259        state.serialize_field(
260            "lower_limit",
261            &self.lower_limit.map(|v| {
262                if v.is_infinite() {
263                    serde_json::Value::String(if v.is_sign_positive() {
264                        "inf".to_string()
265                    } else {
266                        "-inf".to_string()
267                    })
268                } else {
269                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
270                }
271            }),
272        )?;
273        state.serialize_field(
274            "upper_limit",
275            &self.upper_limit.map(|v| {
276                if v.is_infinite() {
277                    serde_json::Value::String(if v.is_sign_positive() {
278                        "inf".to_string()
279                    } else {
280                        "-inf".to_string()
281                    })
282                } else {
283                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
284                }
285            }),
286        )?;
287        state.serialize_field("proportion", &self.proportion)?;
288        state.end()
289    }
290}
291
292impl<'de> Deserialize<'de> for Bin {
293    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
294    where
295        D: Deserializer<'de>,
296    {
297        #[derive(Deserialize)]
298        #[serde(untagged)]
299        enum NumberOrString {
300            Number(f64),
301            String(String),
302        }
303
304        #[derive(Deserialize)]
305        #[serde(field_identifier, rename_all = "snake_case")]
306        enum Field {
307            Id,
308            LowerLimit,
309            UpperLimit,
310            Proportion,
311        }
312
313        struct BinVisitor;
314
315        impl<'de> Visitor<'de> for BinVisitor {
316            type Value = Bin;
317
318            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
319                formatter.write_str("struct Bin")
320            }
321
322            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
323            where
324                V: MapAccess<'de>,
325            {
326                let mut id = None;
327                let mut lower_limit = None;
328                let mut upper_limit = None;
329                let mut proportion = None;
330
331                while let Some(key) = map.next_key()? {
332                    match key {
333                        Field::Id => {
334                            id = Some(map.next_value()?);
335                        }
336                        Field::LowerLimit => {
337                            let val: Option<NumberOrString> = map.next_value()?;
338                            lower_limit = Some(val.map(|v| match v {
339                                NumberOrString::String(s) => match s.as_str() {
340                                    "inf" => f64::INFINITY,
341                                    "-inf" => f64::NEG_INFINITY,
342                                    _ => s.parse().unwrap(),
343                                },
344                                NumberOrString::Number(n) => n,
345                            }));
346                        }
347                        Field::UpperLimit => {
348                            let val: Option<NumberOrString> = map.next_value()?;
349                            upper_limit = Some(val.map(|v| match v {
350                                NumberOrString::String(s) => match s.as_str() {
351                                    "inf" => f64::INFINITY,
352                                    "-inf" => f64::NEG_INFINITY,
353                                    _ => s.parse().unwrap(),
354                                },
355                                NumberOrString::Number(n) => n,
356                            }));
357                        }
358                        Field::Proportion => {
359                            proportion = Some(map.next_value()?);
360                        }
361                    }
362                }
363
364                Ok(Bin {
365                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
366                    lower_limit: lower_limit
367                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
368                    upper_limit: upper_limit
369                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
370                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
371                })
372            }
373        }
374
375        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
376        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
377    }
378}
379
380#[pyclass]
381#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
382pub struct PsiFeatureDriftProfile {
383    #[pyo3(get)]
384    pub id: String,
385
386    #[pyo3(get)]
387    pub bins: Vec<Bin>,
388
389    #[pyo3(get)]
390    pub timestamp: chrono::DateTime<Utc>,
391
392    #[pyo3(get)]
393    pub bin_type: BinType,
394}
395
396#[pyclass]
397#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
398pub struct PsiDriftProfile {
399    #[pyo3(get)]
400    pub features: HashMap<String, PsiFeatureDriftProfile>,
401
402    #[pyo3(get)]
403    pub config: PsiDriftConfig,
404
405    #[pyo3(get)]
406    pub scouter_version: String,
407}
408
409impl PsiDriftProfile {
410    pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
411        Self {
412            features,
413            config,
414            scouter_version: scouter_version(),
415        }
416    }
417}
418
419#[pymethods]
420impl PsiDriftProfile {
421    pub fn __str__(&self) -> String {
422        // serialize the struct to a string
423        PyHelperFuncs::__str__(self)
424    }
425
426    pub fn model_dump_json(&self) -> String {
427        // serialize the struct to a string
428        PyHelperFuncs::__json__(self)
429    }
430    // TODO dry this out
431    #[allow(clippy::useless_conversion)]
432    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
433        let json_str = serde_json::to_string(&self)?;
434
435        let json_value: Value = serde_json::from_str(&json_str)?;
436
437        // Create a new Python dictionary
438        let dict = PyDict::new(py);
439
440        // Convert JSON to Python dict
441        json_to_pyobject(py, &json_value, &dict)?;
442
443        // Return the Python dictionary
444        Ok(dict.into())
445    }
446
447    #[staticmethod]
448    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
449        let file = std::fs::read_to_string(&path)?;
450
451        Ok(serde_json::from_str(&file)?)
452    }
453
454    #[staticmethod]
455    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
456        let json_value = pyobject_to_json(data).unwrap();
457
458        let string = serde_json::to_string(&json_value).unwrap();
459        serde_json::from_str(&string).expect("Failed to load drift profile")
460    }
461
462    #[staticmethod]
463    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
464        // deserialize the string to a struct
465        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
466    }
467
468    // Convert python dict into a drift profile
469    #[pyo3(signature = (path=None))]
470    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
471        Ok(PyHelperFuncs::save_to_json(
472            self,
473            path,
474            FileName::PsiDriftProfile.to_str(),
475        )?)
476    }
477
478    #[allow(clippy::too_many_arguments)]
479    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None, categorical_features=None, binning_strategy=None))]
480    pub fn update_config_args(
481        &mut self,
482        space: Option<String>,
483        name: Option<String>,
484        version: Option<String>,
485        alert_config: Option<PsiAlertConfig>,
486        categorical_features: Option<Vec<String>>,
487        binning_strategy: Option<&Bound<'_, PyAny>>,
488    ) -> Result<(), TypeError> {
489        self.config.update_config_args(
490            space,
491            name,
492            version,
493            alert_config,
494            categorical_features,
495            binning_strategy,
496        )
497    }
498
499    /// Create a profile request from the profile
500    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
501        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
502            None
503        } else {
504            Some(self.config.version.clone())
505        };
506
507        Ok(ProfileRequest {
508            space: self.config.space.clone(),
509            profile: self.model_dump_json(),
510            drift_type: self.config.drift_type.clone(),
511            version_request: VersionRequest {
512                version,
513                version_type: VersionType::Minor,
514                pre_tag: None,
515                build_tag: None,
516            },
517            active: false,
518            deactivate_others: false,
519        })
520    }
521}
522
523#[pyclass]
524#[derive(Debug, Serialize, Deserialize, Clone)]
525pub struct PsiDriftMap {
526    #[pyo3(get)]
527    pub features: HashMap<String, f64>,
528
529    #[pyo3(get)]
530    pub name: String,
531
532    #[pyo3(get)]
533    pub space: String,
534
535    #[pyo3(get)]
536    pub version: String,
537}
538
539impl PsiDriftMap {
540    pub fn new(space: String, name: String, version: String) -> Self {
541        Self {
542            features: HashMap::new(),
543            name,
544            space,
545            version,
546        }
547    }
548}
549
550#[pymethods]
551#[allow(clippy::new_without_default)]
552impl PsiDriftMap {
553    pub fn __str__(&self) -> String {
554        // serialize the struct to a string
555        PyHelperFuncs::__str__(self)
556    }
557
558    pub fn model_dump_json(&self) -> String {
559        // serialize the struct to a string
560        PyHelperFuncs::__json__(self)
561    }
562
563    #[staticmethod]
564    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
565        // deserialize the string to a struct
566        Ok(serde_json::from_str(&json_string)?)
567    }
568
569    #[pyo3(signature = (path=None))]
570    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
571        Ok(PyHelperFuncs::save_to_json(
572            self,
573            path,
574            FileName::PsiDriftMap.to_str(),
575        )?)
576    }
577}
578
579// TODO dry this out
580impl ProfileBaseArgs for PsiDriftProfile {
581    /// Get the base arguments for the profile (convenience method on the server)
582    fn get_base_args(&self) -> ProfileArgs {
583        ProfileArgs {
584            name: self.config.name.clone(),
585            space: self.config.space.clone(),
586            version: Some(self.config.version.clone()),
587            schedule: self.config.alert_config.schedule.clone(),
588            scouter_version: self.scouter_version.clone(),
589            drift_type: self.config.drift_type.clone(),
590        }
591    }
592
593    /// Convert the struct to a serde_json::Value
594    fn to_value(&self) -> Value {
595        serde_json::to_value(self).unwrap()
596    }
597}
598
599#[derive(Debug, Clone, Serialize, Deserialize)]
600pub struct DistributionData {
601    pub sample_size: u64,
602    pub bins: BTreeMap<usize, f64>,
603}
604
605#[derive(Debug, Clone, Serialize, Deserialize)]
606pub struct FeatureDistributions {
607    pub distributions: BTreeMap<String, DistributionData>,
608}
609
610impl FeatureDistributions {
611    pub fn is_empty(&self) -> bool {
612        self.distributions.is_empty()
613    }
614}