// scouter_types/psi/profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::binning::equal_width::EqualWidthBinning;
3use crate::binning::quantile::QuantileBinning;
4use crate::binning::strategy::BinningStrategy;
5use crate::error::{ProfileError, TypeError};
6use crate::psi::alert::PsiAlertConfig;
7use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version, ConfigExt};
8use crate::ProfileRequest;
9use crate::VersionRequest;
10use crate::{
11    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
12    PyHelperFuncs, DEFAULT_VERSION, MISSING,
13};
14use chrono::Utc;
15use core::fmt::Debug;
16use potato_head::create_uuid7;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use scouter_semver::VersionType;
20use serde::de::{self, MapAccess, Visitor};
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer};
23use serde_json::Value;
24use std::collections::{BTreeMap, HashMap};
25use std::path::PathBuf;
26use tracing::debug;
27
28#[pyclass(eq)]
29#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
30pub enum BinType {
31    Numeric,
32    Category,
33}
34
35#[pyclass]
36#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
37pub struct PsiDriftConfig {
38    #[pyo3(get, set)]
39    pub space: String,
40
41    #[pyo3(get, set)]
42    pub name: String,
43
44    #[pyo3(get, set)]
45    pub version: String,
46
47    #[pyo3(get)]
48    pub uid: String,
49
50    #[pyo3(get, set)]
51    pub alert_config: PsiAlertConfig,
52
53    #[pyo3(get)]
54    #[serde(default)]
55    pub feature_map: FeatureMap,
56
57    #[pyo3(get)]
58    #[serde(default = "default_drift_type")]
59    pub drift_type: DriftType,
60
61    #[pyo3(get, set)]
62    pub categorical_features: Option<Vec<String>>,
63
64    pub binning_strategy: BinningStrategy,
65}
66
67impl ConfigExt for PsiDriftConfig {
68    fn space(&self) -> &str {
69        &self.space
70    }
71
72    fn name(&self) -> &str {
73        &self.name
74    }
75
76    fn version(&self) -> &str {
77        &self.version
78    }
79}
80
81fn default_drift_type() -> DriftType {
82    DriftType::Psi
83}
84
85impl PsiDriftConfig {
86    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
87        self.feature_map = feature_map;
88    }
89}
90
91#[pymethods]
92#[allow(clippy::too_many_arguments)]
93impl PsiDriftConfig {
94    #[new]
95    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None, binning_strategy=None))]
96    pub fn new(
97        space: &str,
98        name: &str,
99        version: &str,
100        alert_config: PsiAlertConfig,
101        config_path: Option<PathBuf>,
102        categorical_features: Option<Vec<String>>,
103        binning_strategy: Option<&Bound<'_, PyAny>>,
104    ) -> Result<Self, ProfileError> {
105        if let Some(config_path) = config_path {
106            let config = PsiDriftConfig::load_from_json_file(config_path);
107            return config;
108        }
109
110        let binning_strategy = match binning_strategy {
111            None => BinningStrategy::default(),
112            Some(strategy) => {
113                if strategy.is_instance_of::<QuantileBinning>() {
114                    BinningStrategy::QuantileBinning(strategy.extract()?)
115                } else if strategy.is_instance_of::<EqualWidthBinning>() {
116                    BinningStrategy::EqualWidthBinning(strategy.extract()?)
117                } else {
118                    return Err(ProfileError::InvalidBinningStrategyError);
119                }
120            }
121        };
122
123        if name == MISSING || space == MISSING {
124            debug!("Name and space were not provided. Defaulting to __missing__");
125        }
126
127        Ok(Self {
128            name: name.to_string(),
129            space: space.to_string(),
130            version: version.to_string(),
131            uid: create_uuid7(),
132            alert_config,
133            categorical_features,
134            feature_map: FeatureMap::default(),
135            drift_type: DriftType::Psi,
136            binning_strategy,
137        })
138    }
139
140    #[staticmethod]
141    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
142        // deserialize the string to a struct
143
144        let file = std::fs::read_to_string(&path)?;
145
146        Ok(serde_json::from_str(&file)?)
147    }
148
149    pub fn __str__(&self) -> String {
150        // serialize the struct to a string
151        PyHelperFuncs::__str__(self)
152    }
153
154    pub fn model_dump_json(&self) -> String {
155        // serialize the struct to a string
156        PyHelperFuncs::__json__(self)
157    }
158
159    #[allow(clippy::too_many_arguments)]
160    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None, categorical_features=None, binning_strategy=None))]
161    pub fn update_config_args(
162        &mut self,
163        space: Option<String>,
164        name: Option<String>,
165        version: Option<String>,
166        uid: Option<String>,
167        alert_config: Option<PsiAlertConfig>,
168        categorical_features: Option<Vec<String>>,
169        binning_strategy: Option<&Bound<'_, PyAny>>,
170    ) -> Result<(), TypeError> {
171        if name.is_some() {
172            self.name = name.ok_or(TypeError::MissingNameError)?;
173        }
174
175        if space.is_some() {
176            self.space = space.ok_or(TypeError::MissingSpaceError)?;
177        }
178
179        if version.is_some() {
180            self.version = version.ok_or(TypeError::MissingVersionError)?;
181        }
182
183        if alert_config.is_some() {
184            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
185        }
186
187        if uid.is_some() {
188            self.uid = uid.ok_or(TypeError::MissingUidError)?;
189        }
190
191        if categorical_features.is_some() {
192            self.categorical_features = categorical_features;
193        }
194
195        if let Some(binning_strategy) = binning_strategy {
196            if binning_strategy.is_instance_of::<QuantileBinning>() {
197                self.binning_strategy =
198                    BinningStrategy::QuantileBinning(binning_strategy.extract()?);
199            } else if binning_strategy.is_instance_of::<EqualWidthBinning>() {
200                self.binning_strategy =
201                    BinningStrategy::EqualWidthBinning(binning_strategy.extract()?);
202            } else {
203                return Err(TypeError::InvalidBinningStrategyError);
204            }
205        }
206
207        Ok(())
208    }
209
210    #[getter]
211    pub fn binning_strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
212        self.binning_strategy.strategy(py)
213    }
214
215    #[setter]
216    pub fn set_binning_strategy(&mut self, strategy: &Bound<'_, PyAny>) -> PyResult<()> {
217        if strategy.is_instance_of::<QuantileBinning>() {
218            self.binning_strategy = BinningStrategy::QuantileBinning(strategy.extract()?);
219        } else if strategy.is_instance_of::<EqualWidthBinning>() {
220            self.binning_strategy = BinningStrategy::EqualWidthBinning(strategy.extract()?);
221        } else {
222            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
223                "Invalid binning strategy type",
224            ));
225        }
226        Ok(())
227    }
228}
229
230impl Default for PsiDriftConfig {
231    fn default() -> Self {
232        PsiDriftConfig {
233            name: "__missing__".to_string(),
234            space: "__missing__".to_string(),
235            version: DEFAULT_VERSION.to_string(),
236            uid: MISSING.to_string(),
237            feature_map: FeatureMap::default(),
238            alert_config: PsiAlertConfig::default(),
239            drift_type: DriftType::Psi,
240            categorical_features: None,
241            binning_strategy: BinningStrategy::default(),
242        }
243    }
244}
245// TODO dry this out
246
247impl DispatchDriftConfig for PsiDriftConfig {
248    fn get_drift_args(&self) -> DriftArgs {
249        DriftArgs {
250            name: self.name.clone(),
251            space: self.space.clone(),
252            version: self.version.clone(),
253            dispatch_config: self.alert_config.dispatch_config.clone(),
254        }
255    }
256}
257
258#[pyclass]
259#[derive(Debug, Clone, PartialEq)]
260pub struct Bin {
261    #[pyo3(get)]
262    pub id: i32,
263
264    #[pyo3(get)]
265    pub lower_limit: Option<f64>,
266
267    #[pyo3(get)]
268    pub upper_limit: Option<f64>,
269
270    #[pyo3(get)]
271    pub proportion: f64,
272}
273
274impl Serialize for Bin {
275    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
276    where
277        S: Serializer,
278    {
279        let mut state = serializer.serialize_struct("Bin", 4)?;
280        state.serialize_field("id", &self.id)?;
281
282        state.serialize_field(
283            "lower_limit",
284            &self.lower_limit.map(|v| {
285                if v.is_infinite() {
286                    serde_json::Value::String(if v.is_sign_positive() {
287                        "inf".to_string()
288                    } else {
289                        "-inf".to_string()
290                    })
291                } else {
292                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
293                }
294            }),
295        )?;
296        state.serialize_field(
297            "upper_limit",
298            &self.upper_limit.map(|v| {
299                if v.is_infinite() {
300                    serde_json::Value::String(if v.is_sign_positive() {
301                        "inf".to_string()
302                    } else {
303                        "-inf".to_string()
304                    })
305                } else {
306                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
307                }
308            }),
309        )?;
310        state.serialize_field("proportion", &self.proportion)?;
311        state.end()
312    }
313}
314
315impl<'de> Deserialize<'de> for Bin {
316    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
317    where
318        D: Deserializer<'de>,
319    {
320        #[derive(Deserialize)]
321        #[serde(untagged)]
322        enum NumberOrString {
323            Number(f64),
324            String(String),
325        }
326
327        #[derive(Deserialize)]
328        #[serde(field_identifier, rename_all = "snake_case")]
329        enum Field {
330            Id,
331            LowerLimit,
332            UpperLimit,
333            Proportion,
334        }
335
336        struct BinVisitor;
337
338        impl<'de> Visitor<'de> for BinVisitor {
339            type Value = Bin;
340
341            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
342                formatter.write_str("struct Bin")
343            }
344
345            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
346            where
347                V: MapAccess<'de>,
348            {
349                let mut id = None;
350                let mut lower_limit = None;
351                let mut upper_limit = None;
352                let mut proportion = None;
353
354                while let Some(key) = map.next_key()? {
355                    match key {
356                        Field::Id => {
357                            id = Some(map.next_value()?);
358                        }
359                        Field::LowerLimit => {
360                            let val: Option<NumberOrString> = map.next_value()?;
361                            lower_limit = Some(val.map(|v| match v {
362                                NumberOrString::String(s) => match s.as_str() {
363                                    "inf" => f64::INFINITY,
364                                    "-inf" => f64::NEG_INFINITY,
365                                    _ => s.parse().unwrap(),
366                                },
367                                NumberOrString::Number(n) => n,
368                            }));
369                        }
370                        Field::UpperLimit => {
371                            let val: Option<NumberOrString> = map.next_value()?;
372                            upper_limit = Some(val.map(|v| match v {
373                                NumberOrString::String(s) => match s.as_str() {
374                                    "inf" => f64::INFINITY,
375                                    "-inf" => f64::NEG_INFINITY,
376                                    _ => s.parse().unwrap(),
377                                },
378                                NumberOrString::Number(n) => n,
379                            }));
380                        }
381                        Field::Proportion => {
382                            proportion = Some(map.next_value()?);
383                        }
384                    }
385                }
386
387                Ok(Bin {
388                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
389                    lower_limit: lower_limit
390                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
391                    upper_limit: upper_limit
392                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
393                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
394                })
395            }
396        }
397
398        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
399        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
400    }
401}
402
403#[pyclass]
404#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
405pub struct PsiFeatureDriftProfile {
406    #[pyo3(get)]
407    pub id: String,
408
409    #[pyo3(get)]
410    pub bins: Vec<Bin>,
411
412    #[pyo3(get)]
413    pub timestamp: chrono::DateTime<Utc>,
414
415    #[pyo3(get)]
416    pub bin_type: BinType,
417}
418
419#[pyclass]
420#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
421pub struct PsiDriftProfile {
422    #[pyo3(get)]
423    pub features: HashMap<String, PsiFeatureDriftProfile>,
424
425    #[pyo3(get)]
426    pub config: PsiDriftConfig,
427
428    #[pyo3(get)]
429    pub scouter_version: String,
430}
431
432impl PsiDriftProfile {
433    pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
434        Self {
435            features,
436            config,
437            scouter_version: scouter_version(),
438        }
439    }
440}
441
442#[pymethods]
443impl PsiDriftProfile {
444    pub fn __str__(&self) -> String {
445        // serialize the struct to a string
446        PyHelperFuncs::__str__(self)
447    }
448
449    #[getter]
450    pub fn uid(&self) -> String {
451        self.config.uid.clone()
452    }
453
454    #[setter]
455    pub fn set_uid(&mut self, uid: String) {
456        self.config.uid = uid;
457    }
458
459    pub fn model_dump_json(&self) -> String {
460        // serialize the struct to a string
461        PyHelperFuncs::__json__(self)
462    }
463    // TODO dry this out
464    #[allow(clippy::useless_conversion)]
465    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
466        let json_str = serde_json::to_string(&self)?;
467
468        let json_value: Value = serde_json::from_str(&json_str)?;
469
470        // Create a new Python dictionary
471        let dict = PyDict::new(py);
472
473        // Convert JSON to Python dict
474        json_to_pyobject(py, &json_value, &dict)?;
475
476        // Return the Python dictionary
477        Ok(dict.into())
478    }
479
480    #[staticmethod]
481    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
482        let file = std::fs::read_to_string(&path)?;
483
484        Ok(serde_json::from_str(&file)?)
485    }
486
487    #[staticmethod]
488    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
489        let json_value = pyobject_to_json(data).unwrap();
490
491        let string = serde_json::to_string(&json_value).unwrap();
492        serde_json::from_str(&string).expect("Failed to load drift profile")
493    }
494
495    #[staticmethod]
496    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
497        // deserialize the string to a struct
498        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
499    }
500
501    // Convert python dict into a drift profile
502    #[pyo3(signature = (path=None))]
503    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
504        Ok(PyHelperFuncs::save_to_json(
505            self,
506            path,
507            FileName::PsiDriftProfile.to_str(),
508        )?)
509    }
510
511    #[allow(clippy::too_many_arguments)]
512    #[pyo3(signature = (space=None, name=None, version=None, uid=None,alert_config=None, categorical_features=None, binning_strategy=None))]
513    pub fn update_config_args(
514        &mut self,
515        space: Option<String>,
516        name: Option<String>,
517        version: Option<String>,
518        uid: Option<String>,
519        alert_config: Option<PsiAlertConfig>,
520        categorical_features: Option<Vec<String>>,
521        binning_strategy: Option<&Bound<'_, PyAny>>,
522    ) -> Result<(), TypeError> {
523        self.config.update_config_args(
524            space,
525            name,
526            version,
527            uid,
528            alert_config,
529            categorical_features,
530            binning_strategy,
531        )
532    }
533
534    /// Create a profile request from the profile
535    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
536        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
537            None
538        } else {
539            Some(self.config.version.clone())
540        };
541
542        Ok(ProfileRequest {
543            space: self.config.space.clone(),
544            profile: self.model_dump_json(),
545            drift_type: self.config.drift_type.clone(),
546            version_request: Some(VersionRequest {
547                version,
548                version_type: VersionType::Minor,
549                pre_tag: None,
550                build_tag: None,
551            }),
552            active: false,
553            deactivate_others: false,
554        })
555    }
556}
557
558#[pyclass]
559#[derive(Debug, Serialize, Deserialize, Clone)]
560pub struct PsiDriftMap {
561    #[pyo3(get)]
562    pub features: HashMap<String, f64>,
563
564    #[pyo3(get)]
565    pub name: String,
566
567    #[pyo3(get)]
568    pub space: String,
569
570    #[pyo3(get)]
571    pub version: String,
572}
573
574impl PsiDriftMap {
575    pub fn new(space: String, name: String, version: String) -> Self {
576        Self {
577            features: HashMap::new(),
578            name,
579            space,
580            version,
581        }
582    }
583}
584
585#[pymethods]
586#[allow(clippy::new_without_default)]
587impl PsiDriftMap {
588    pub fn __str__(&self) -> String {
589        // serialize the struct to a string
590        PyHelperFuncs::__str__(self)
591    }
592
593    pub fn model_dump_json(&self) -> String {
594        // serialize the struct to a string
595        PyHelperFuncs::__json__(self)
596    }
597
598    #[staticmethod]
599    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
600        // deserialize the string to a struct
601        Ok(serde_json::from_str(&json_string)?)
602    }
603
604    #[pyo3(signature = (path=None))]
605    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
606        Ok(PyHelperFuncs::save_to_json(
607            self,
608            path,
609            FileName::PsiDriftMap.to_str(),
610        )?)
611    }
612}
613
614// TODO dry this out
615impl ProfileBaseArgs for PsiDriftProfile {
616    type Config = PsiDriftConfig;
617
618    fn config(&self) -> &Self::Config {
619        &self.config
620    }
621    /// Get the base arguments for the profile (convenience method on the server)
622    fn get_base_args(&self) -> ProfileArgs {
623        ProfileArgs {
624            name: self.config.name.clone(),
625            space: self.config.space.clone(),
626            version: Some(self.config.version.clone()),
627            schedule: self.config.alert_config.schedule.clone(),
628            scouter_version: self.scouter_version.clone(),
629            drift_type: self.config.drift_type.clone(),
630        }
631    }
632
633    /// Convert the struct to a serde_json::Value
634    fn to_value(&self) -> Value {
635        serde_json::to_value(self).unwrap()
636    }
637}
638
639#[derive(Debug, Clone, Serialize, Deserialize)]
640pub struct DistributionData {
641    pub sample_size: u64,
642    pub bins: BTreeMap<i32, f64>,
643}
644
645#[derive(Debug, Clone, Serialize, Deserialize)]
646pub struct FeatureDistributions {
647    pub distributions: BTreeMap<String, DistributionData>,
648}
649
650impl FeatureDistributions {
651    pub fn is_empty(&self) -> bool {
652        self.distributions.is_empty()
653    }
654}
655
656#[derive(Debug)]
657pub struct FeatureDistributionRow {
658    pub feature: String,
659    pub distribution: DistributionData,
660}
661
#[cfg(feature = "server")]
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for FeatureDistributionRow {
    /// Decode a row with `feature` (text), `sample_size` (decoded as `i64`) and `bins` (JSON).
    ///
    /// # Errors
    /// Returns the column lookup/decode error from `try_get`. Returns `sqlx::Error::Decode` when
    /// `sample_size` is negative or `bins` is not a JSON map of `i32` keys to `f64` values.
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;

        let feature: String = row.try_get("feature")?;
        let sample_size: i64 = row.try_get("sample_size")?;
        // A negative count is corrupt data; `as u64` would silently wrap it to a huge value.
        let sample_size =
            u64::try_from(sample_size).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;
        let bins_json: serde_json::Value = row.try_get("bins")?;
        let bins: BTreeMap<i32, f64> =
            serde_json::from_value(bins_json).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        Ok(FeatureDistributionRow {
            feature,
            distribution: DistributionData { sample_size, bins },
        })
    }
}
682
683impl FeatureDistributions {
684    /// Convert from a vector of rows into the final structure
685    pub fn from_rows(rows: Vec<FeatureDistributionRow>) -> Self {
686        let distributions = rows
687            .into_iter()
688            .map(|row| (row.feature, row.distribution))
689            .collect();
690
691        FeatureDistributions { distributions }
692    }
693}