Skip to main content

scouter_types/psi/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::binning::equal_width::EqualWidthBinning;
3use crate::binning::quantile::QuantileBinning;
4use crate::binning::strategy::BinningStrategy;
5use crate::error::{ProfileError, TypeError};
6use crate::psi::alert::PsiAlertConfig;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
9use crate::ProfileRequest;
10use crate::VersionRequest;
11use crate::{
12    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
13    PyHelperFuncs, DEFAULT_VERSION, MISSING,
14};
15use chrono::Utc;
16use core::fmt::Debug;
17use potato_head::create_uuid7;
18use pyo3::prelude::*;
19use pyo3::types::PyDict;
20use scouter_semver::VersionType;
21use serde::de::{self, MapAccess, Visitor};
22use serde::ser::SerializeStruct;
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24use serde_json::Value;
25use std::collections::{BTreeMap, HashMap};
26use std::path::PathBuf;
27use tracing::debug;
28
29#[pyclass(eq)]
30#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
31pub enum BinType {
32    Numeric,
33    Category,
34}
35
36#[pyclass]
37#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
38pub struct PsiDriftConfig {
39    #[pyo3(get, set)]
40    pub space: String,
41
42    #[pyo3(get, set)]
43    pub name: String,
44
45    #[pyo3(get, set)]
46    pub version: String,
47
48    #[pyo3(get)]
49    pub uid: String,
50
51    #[pyo3(get, set)]
52    pub alert_config: PsiAlertConfig,
53
54    #[pyo3(get)]
55    #[serde(default)]
56    pub feature_map: FeatureMap,
57
58    #[pyo3(get)]
59    #[serde(default = "default_drift_type")]
60    pub drift_type: DriftType,
61
62    #[pyo3(get, set)]
63    pub categorical_features: Option<Vec<String>>,
64
65    pub binning_strategy: BinningStrategy,
66}
67
68impl ConfigExt for PsiDriftConfig {
69    fn space(&self) -> &str {
70        &self.space
71    }
72
73    fn name(&self) -> &str {
74        &self.name
75    }
76
77    fn version(&self) -> &str {
78        &self.version
79    }
80    fn uid(&self) -> &str {
81        &self.uid
82    }
83}
84
85fn default_drift_type() -> DriftType {
86    DriftType::Psi
87}
88
89impl PsiDriftConfig {
90    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
91        self.feature_map = feature_map;
92    }
93}
94
95#[pymethods]
96#[allow(clippy::too_many_arguments)]
97impl PsiDriftConfig {
98    #[new]
99    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None, binning_strategy=None))]
100    pub fn new(
101        space: &str,
102        name: &str,
103        version: &str,
104        alert_config: PsiAlertConfig,
105        config_path: Option<PathBuf>,
106        categorical_features: Option<Vec<String>>,
107        binning_strategy: Option<&Bound<'_, PyAny>>,
108    ) -> Result<Self, ProfileError> {
109        if let Some(config_path) = config_path {
110            let config = PsiDriftConfig::load_from_json_file(config_path);
111            return config;
112        }
113
114        let binning_strategy = match binning_strategy {
115            None => BinningStrategy::default(),
116            Some(strategy) => {
117                if strategy.is_instance_of::<QuantileBinning>() {
118                    BinningStrategy::QuantileBinning(strategy.extract()?)
119                } else if strategy.is_instance_of::<EqualWidthBinning>() {
120                    BinningStrategy::EqualWidthBinning(strategy.extract()?)
121                } else {
122                    return Err(ProfileError::InvalidBinningStrategyError);
123                }
124            }
125        };
126
127        if name == MISSING || space == MISSING {
128            debug!("Name and space were not provided. Defaulting to __missing__");
129        }
130
131        Ok(Self {
132            name: name.to_string(),
133            space: space.to_string(),
134            version: version.to_string(),
135            uid: create_uuid7(),
136            alert_config,
137            categorical_features,
138            feature_map: FeatureMap::default(),
139            drift_type: DriftType::Psi,
140            binning_strategy,
141        })
142    }
143
144    #[staticmethod]
145    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
146        // deserialize the string to a struct
147
148        let file = std::fs::read_to_string(&path)?;
149
150        Ok(serde_json::from_str(&file)?)
151    }
152
153    pub fn __str__(&self) -> String {
154        // serialize the struct to a string
155        PyHelperFuncs::__str__(self)
156    }
157
158    pub fn model_dump_json(&self) -> String {
159        // serialize the struct to a string
160        PyHelperFuncs::__json__(self)
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None, categorical_features=None, binning_strategy=None))]
165    pub fn update_config_args(
166        &mut self,
167        space: Option<String>,
168        name: Option<String>,
169        version: Option<String>,
170        uid: Option<String>,
171        alert_config: Option<PsiAlertConfig>,
172        categorical_features: Option<Vec<String>>,
173        binning_strategy: Option<&Bound<'_, PyAny>>,
174    ) -> Result<(), TypeError> {
175        if name.is_some() {
176            self.name = name.ok_or(TypeError::MissingNameError)?;
177        }
178
179        if space.is_some() {
180            self.space = space.ok_or(TypeError::MissingSpaceError)?;
181        }
182
183        if version.is_some() {
184            self.version = version.ok_or(TypeError::MissingVersionError)?;
185        }
186
187        if alert_config.is_some() {
188            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
189        }
190
191        if uid.is_some() {
192            self.uid = uid.ok_or(TypeError::MissingUidError)?;
193        }
194
195        if categorical_features.is_some() {
196            self.categorical_features = categorical_features;
197        }
198
199        if let Some(binning_strategy) = binning_strategy {
200            if binning_strategy.is_instance_of::<QuantileBinning>() {
201                self.binning_strategy =
202                    BinningStrategy::QuantileBinning(binning_strategy.extract()?);
203            } else if binning_strategy.is_instance_of::<EqualWidthBinning>() {
204                self.binning_strategy =
205                    BinningStrategy::EqualWidthBinning(binning_strategy.extract()?);
206            } else {
207                return Err(TypeError::InvalidBinningStrategyError);
208            }
209        }
210
211        Ok(())
212    }
213
214    #[getter]
215    pub fn binning_strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
216        self.binning_strategy.strategy(py)
217    }
218
219    #[setter]
220    pub fn set_binning_strategy(&mut self, strategy: &Bound<'_, PyAny>) -> PyResult<()> {
221        if strategy.is_instance_of::<QuantileBinning>() {
222            self.binning_strategy = BinningStrategy::QuantileBinning(strategy.extract()?);
223        } else if strategy.is_instance_of::<EqualWidthBinning>() {
224            self.binning_strategy = BinningStrategy::EqualWidthBinning(strategy.extract()?);
225        } else {
226            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
227                "Invalid binning strategy type",
228            ));
229        }
230        Ok(())
231    }
232}
233
234impl Default for PsiDriftConfig {
235    fn default() -> Self {
236        PsiDriftConfig {
237            name: "__missing__".to_string(),
238            space: "__missing__".to_string(),
239            version: DEFAULT_VERSION.to_string(),
240            uid: MISSING.to_string(),
241            feature_map: FeatureMap::default(),
242            alert_config: PsiAlertConfig::default(),
243            drift_type: DriftType::Psi,
244            categorical_features: None,
245            binning_strategy: BinningStrategy::default(),
246        }
247    }
248}
249// TODO dry this out
250
251impl DispatchDriftConfig for PsiDriftConfig {
252    fn get_drift_args(&self) -> DriftArgs {
253        DriftArgs {
254            name: self.name.clone(),
255            space: self.space.clone(),
256            version: self.version.clone(),
257            dispatch_config: self.alert_config.dispatch_config.clone(),
258        }
259    }
260}
261
262#[pyclass]
263#[derive(Debug, Clone, PartialEq)]
264pub struct Bin {
265    #[pyo3(get)]
266    pub id: i32,
267
268    #[pyo3(get)]
269    pub lower_limit: Option<f64>,
270
271    #[pyo3(get)]
272    pub upper_limit: Option<f64>,
273
274    #[pyo3(get)]
275    pub proportion: f64,
276}
277
278impl Serialize for Bin {
279    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280    where
281        S: Serializer,
282    {
283        let mut state = serializer.serialize_struct("Bin", 4)?;
284        state.serialize_field("id", &self.id)?;
285
286        state.serialize_field(
287            "lower_limit",
288            &self.lower_limit.map(|v| {
289                if v.is_infinite() {
290                    serde_json::Value::String(if v.is_sign_positive() {
291                        "inf".to_string()
292                    } else {
293                        "-inf".to_string()
294                    })
295                } else {
296                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
297                }
298            }),
299        )?;
300        state.serialize_field(
301            "upper_limit",
302            &self.upper_limit.map(|v| {
303                if v.is_infinite() {
304                    serde_json::Value::String(if v.is_sign_positive() {
305                        "inf".to_string()
306                    } else {
307                        "-inf".to_string()
308                    })
309                } else {
310                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
311                }
312            }),
313        )?;
314        state.serialize_field("proportion", &self.proportion)?;
315        state.end()
316    }
317}
318
319impl<'de> Deserialize<'de> for Bin {
320    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
321    where
322        D: Deserializer<'de>,
323    {
324        #[derive(Deserialize)]
325        #[serde(untagged)]
326        enum NumberOrString {
327            Number(f64),
328            String(String),
329        }
330
331        #[derive(Deserialize)]
332        #[serde(field_identifier, rename_all = "snake_case")]
333        enum Field {
334            Id,
335            LowerLimit,
336            UpperLimit,
337            Proportion,
338        }
339
340        struct BinVisitor;
341
342        impl<'de> Visitor<'de> for BinVisitor {
343            type Value = Bin;
344
345            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
346                formatter.write_str("struct Bin")
347            }
348
349            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
350            where
351                V: MapAccess<'de>,
352            {
353                let mut id = None;
354                let mut lower_limit = None;
355                let mut upper_limit = None;
356                let mut proportion = None;
357
358                while let Some(key) = map.next_key()? {
359                    match key {
360                        Field::Id => {
361                            id = Some(map.next_value()?);
362                        }
363                        Field::LowerLimit => {
364                            let val: Option<NumberOrString> = map.next_value()?;
365                            lower_limit = Some(val.map(|v| match v {
366                                NumberOrString::String(s) => match s.as_str() {
367                                    "inf" => f64::INFINITY,
368                                    "-inf" => f64::NEG_INFINITY,
369                                    _ => s.parse().unwrap(),
370                                },
371                                NumberOrString::Number(n) => n,
372                            }));
373                        }
374                        Field::UpperLimit => {
375                            let val: Option<NumberOrString> = map.next_value()?;
376                            upper_limit = Some(val.map(|v| match v {
377                                NumberOrString::String(s) => match s.as_str() {
378                                    "inf" => f64::INFINITY,
379                                    "-inf" => f64::NEG_INFINITY,
380                                    _ => s.parse().unwrap(),
381                                },
382                                NumberOrString::Number(n) => n,
383                            }));
384                        }
385                        Field::Proportion => {
386                            proportion = Some(map.next_value()?);
387                        }
388                    }
389                }
390
391                Ok(Bin {
392                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
393                    lower_limit: lower_limit
394                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
395                    upper_limit: upper_limit
396                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
397                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
398                })
399            }
400        }
401
402        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
403        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
404    }
405}
406
407#[pyclass]
408#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
409pub struct PsiFeatureDriftProfile {
410    #[pyo3(get)]
411    pub id: String,
412
413    #[pyo3(get)]
414    pub bins: Vec<Bin>,
415
416    #[pyo3(get)]
417    pub timestamp: chrono::DateTime<Utc>,
418
419    #[pyo3(get)]
420    pub bin_type: BinType,
421}
422
423#[pyclass]
424#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
425pub struct PsiDriftProfile {
426    #[pyo3(get)]
427    pub features: HashMap<String, PsiFeatureDriftProfile>,
428
429    #[pyo3(get)]
430    pub config: PsiDriftConfig,
431
432    #[pyo3(get)]
433    pub scouter_version: String,
434}
435
436impl PsiDriftProfile {
437    pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
438        Self {
439            features,
440            config,
441            scouter_version: scouter_version(),
442        }
443    }
444}
445
446#[pymethods]
447impl PsiDriftProfile {
448    pub fn __str__(&self) -> String {
449        // serialize the struct to a string
450        PyHelperFuncs::__str__(self)
451    }
452
453    #[getter]
454    pub fn uid(&self) -> String {
455        self.config.uid.clone()
456    }
457
458    #[setter]
459    pub fn set_uid(&mut self, uid: String) {
460        self.config.uid = uid;
461    }
462
463    #[getter]
464    pub fn drift_type(&self) -> DriftType {
465        self.config.drift_type.clone()
466    }
467    pub fn model_dump_json(&self) -> String {
468        // serialize the struct to a string
469        PyHelperFuncs::__json__(self)
470    }
471    // TODO dry this out
472    #[allow(clippy::useless_conversion)]
473    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
474        let json_str = serde_json::to_string(&self)?;
475
476        let json_value: Value = serde_json::from_str(&json_str)?;
477
478        // Create a new Python dictionary
479        let dict = PyDict::new(py);
480
481        // Convert JSON to Python dict
482        json_to_pyobject(py, &json_value, &dict)?;
483
484        // Return the Python dictionary
485        Ok(dict.into())
486    }
487
488    #[staticmethod]
489    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
490        let file = std::fs::read_to_string(&path)?;
491
492        Ok(serde_json::from_str(&file)?)
493    }
494
495    #[staticmethod]
496    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
497        let json_value = pyobject_to_json(data).unwrap();
498
499        let string = serde_json::to_string(&json_value).unwrap();
500        serde_json::from_str(&string).expect("Failed to load drift profile")
501    }
502
503    #[staticmethod]
504    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
505        // deserialize the string to a struct
506        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
507    }
508
509    // Convert python dict into a drift profile
510    #[pyo3(signature = (path=None))]
511    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
512        Ok(PyHelperFuncs::save_to_json(
513            self,
514            path,
515            FileName::PsiDriftProfile.to_str(),
516        )?)
517    }
518
519    #[allow(clippy::too_many_arguments)]
520    #[pyo3(signature = (space=None, name=None, version=None, uid=None,alert_config=None, categorical_features=None, binning_strategy=None))]
521    pub fn update_config_args(
522        &mut self,
523        space: Option<String>,
524        name: Option<String>,
525        version: Option<String>,
526        uid: Option<String>,
527        alert_config: Option<PsiAlertConfig>,
528        categorical_features: Option<Vec<String>>,
529        binning_strategy: Option<&Bound<'_, PyAny>>,
530    ) -> Result<(), TypeError> {
531        self.config.update_config_args(
532            space,
533            name,
534            version,
535            uid,
536            alert_config,
537            categorical_features,
538            binning_strategy,
539        )
540    }
541
542    /// Create a profile request from the profile
543    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
544        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
545            None
546        } else {
547            Some(self.config.version.clone())
548        };
549
550        Ok(ProfileRequest {
551            space: self.config.space.clone(),
552            profile: self.model_dump_json(),
553            drift_type: self.config.drift_type.clone(),
554            version_request: Some(VersionRequest {
555                version,
556                version_type: VersionType::Minor,
557                pre_tag: None,
558                build_tag: None,
559            }),
560            active: false,
561            deactivate_others: false,
562        })
563    }
564}
565
566#[pyclass]
567#[derive(Debug, Serialize, Deserialize, Clone)]
568pub struct PsiDriftMap {
569    #[pyo3(get)]
570    pub features: HashMap<String, f64>,
571
572    #[pyo3(get)]
573    pub name: String,
574
575    #[pyo3(get)]
576    pub space: String,
577
578    #[pyo3(get)]
579    pub version: String,
580}
581
582impl PsiDriftMap {
583    pub fn new(space: String, name: String, version: String) -> Self {
584        Self {
585            features: HashMap::new(),
586            name,
587            space,
588            version,
589        }
590    }
591}
592
593#[pymethods]
594#[allow(clippy::new_without_default)]
595impl PsiDriftMap {
596    pub fn __str__(&self) -> String {
597        // serialize the struct to a string
598        PyHelperFuncs::__str__(self)
599    }
600
601    pub fn model_dump_json(&self) -> String {
602        // serialize the struct to a string
603        PyHelperFuncs::__json__(self)
604    }
605
606    #[staticmethod]
607    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
608        // deserialize the string to a struct
609        Ok(serde_json::from_str(&json_string)?)
610    }
611
612    #[pyo3(signature = (path=None))]
613    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
614        Ok(PyHelperFuncs::save_to_json(
615            self,
616            path,
617            FileName::PsiDriftMap.to_str(),
618        )?)
619    }
620}
621
622// TODO dry this out
623impl ProfileBaseArgs for PsiDriftProfile {
624    type Config = PsiDriftConfig;
625
626    fn config(&self) -> &Self::Config {
627        &self.config
628    }
629    /// Get the base arguments for the profile (convenience method on the server)
630    fn get_base_args(&self) -> ProfileArgs {
631        ProfileArgs {
632            name: self.config.name.clone(),
633            space: self.config.space.clone(),
634            version: Some(self.config.version.clone()),
635            schedule: self.config.alert_config.schedule.clone(),
636            scouter_version: self.scouter_version.clone(),
637            drift_type: self.config.drift_type.clone(),
638        }
639    }
640
641    /// Convert the struct to a serde_json::Value
642    fn to_value(&self) -> Value {
643        serde_json::to_value(self).unwrap()
644    }
645}
646
647#[derive(Debug, Clone, Serialize, Deserialize)]
648pub struct DistributionData {
649    pub sample_size: u64,
650    pub bins: BTreeMap<i32, f64>,
651}
652
653#[derive(Debug, Clone, Serialize, Deserialize)]
654pub struct FeatureDistributions {
655    pub distributions: BTreeMap<String, DistributionData>,
656}
657
658impl FeatureDistributions {
659    pub fn is_empty(&self) -> bool {
660        self.distributions.is_empty()
661    }
662}
663
664#[derive(Debug)]
665pub struct FeatureDistributionRow {
666    pub feature: String,
667    pub distribution: DistributionData,
668}
669
#[cfg(feature = "server")]
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for FeatureDistributionRow {
    /// Decode a row with the columns `feature` (text), `sample_size` (i64)
    /// and `bins` (JSON object mapping `i32` keys to `f64` values).
    ///
    /// # Errors
    /// * Any column lookup or type error reported by `try_get`.
    /// * `sqlx::Error::Decode` if `sample_size` is negative (it cannot be
    ///   represented as `u64`; a plain `as` cast would silently wrap it into
    ///   a huge count).
    /// * `sqlx::Error::Decode` if `bins` does not deserialize into the
    ///   expected map.
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;

        let feature: String = row.try_get("feature")?;

        let raw_sample_size: i64 = row.try_get("sample_size")?;
        let sample_size =
            u64::try_from(raw_sample_size).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        let bins_json: serde_json::Value = row.try_get("bins")?;
        let bins: BTreeMap<i32, f64> =
            serde_json::from_value(bins_json).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        Ok(FeatureDistributionRow {
            feature,
            distribution: DistributionData { sample_size, bins },
        })
    }
}
690
691impl FeatureDistributions {
692    /// Convert from a vector of rows into the final structure
693    pub fn from_rows(rows: Vec<FeatureDistributionRow>) -> Self {
694        let distributions = rows
695            .into_iter()
696            .map(|row| (row.feature, row.distribution))
697            .collect();
698
699        FeatureDistributions { distributions }
700    }
701}