scouter_types/psi/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::ProfileRequest;
6use crate::VersionRequest;
7use crate::{
8    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
9    ProfileFuncs, DEFAULT_VERSION, MISSING,
10};
11use chrono::Utc;
12use core::fmt::Debug;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15use scouter_semver::VersionType;
16use serde::de::{self, MapAccess, Visitor};
17use serde::ser::SerializeStruct;
18use serde::{Deserialize, Deserializer, Serialize, Serializer};
19use serde_json::Value;
20use std::collections::{BTreeMap, HashMap};
21use std::path::PathBuf;
22use tracing::debug;
23
24#[pyclass(eq)]
25#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
26pub enum BinType {
27    Numeric,
28    Category,
29}
30
31#[pyclass]
32#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
33pub struct PsiDriftConfig {
34    #[pyo3(get, set)]
35    pub space: String,
36
37    #[pyo3(get, set)]
38    pub name: String,
39
40    #[pyo3(get, set)]
41    pub version: String,
42
43    #[pyo3(get, set)]
44    pub alert_config: PsiAlertConfig,
45
46    #[pyo3(get)]
47    #[serde(default)]
48    pub feature_map: FeatureMap,
49
50    #[pyo3(get, set)]
51    #[serde(default = "default_drift_type")]
52    pub drift_type: DriftType,
53
54    #[pyo3(get, set)]
55    pub categorical_features: Option<Vec<String>>,
56}
57
58fn default_drift_type() -> DriftType {
59    DriftType::Psi
60}
61
62impl PsiDriftConfig {
63    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
64        self.feature_map = feature_map;
65    }
66}
67
68#[pymethods]
69#[allow(clippy::too_many_arguments)]
70impl PsiDriftConfig {
71    #[new]
72    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None))]
73    pub fn new(
74        space: &str,
75        name: &str,
76        version: &str,
77        alert_config: PsiAlertConfig,
78        config_path: Option<PathBuf>,
79        categorical_features: Option<Vec<String>>,
80    ) -> Result<Self, ProfileError> {
81        if let Some(config_path) = config_path {
82            let config = PsiDriftConfig::load_from_json_file(config_path);
83            return config;
84        }
85
86        if name == MISSING || space == MISSING {
87            debug!("Name and space were not provided. Defaulting to __missing__");
88        }
89
90        Ok(Self {
91            name: name.to_string(),
92            space: space.to_string(),
93            version: version.to_string(),
94            alert_config,
95            categorical_features,
96            feature_map: FeatureMap::default(),
97            drift_type: DriftType::Psi,
98        })
99    }
100
101    #[staticmethod]
102    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
103        // deserialize the string to a struct
104
105        let file = std::fs::read_to_string(&path)?;
106
107        Ok(serde_json::from_str(&file)?)
108    }
109
110    pub fn __str__(&self) -> String {
111        // serialize the struct to a string
112        ProfileFuncs::__str__(self)
113    }
114
115    pub fn model_dump_json(&self) -> String {
116        // serialize the struct to a string
117        ProfileFuncs::__json__(self)
118    }
119
120    #[allow(clippy::too_many_arguments)]
121    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
122    pub fn update_config_args(
123        &mut self,
124        space: Option<String>,
125        name: Option<String>,
126        version: Option<String>,
127        alert_config: Option<PsiAlertConfig>,
128    ) -> Result<(), TypeError> {
129        if name.is_some() {
130            self.name = name.ok_or(TypeError::MissingNameError)?;
131        }
132
133        if space.is_some() {
134            self.space = space.ok_or(TypeError::MissingSpaceError)?;
135        }
136
137        if version.is_some() {
138            self.version = version.ok_or(TypeError::MissingVersionError)?;
139        }
140
141        if alert_config.is_some() {
142            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
143        }
144
145        Ok(())
146    }
147}
148
149impl Default for PsiDriftConfig {
150    fn default() -> Self {
151        PsiDriftConfig {
152            name: "__missing__".to_string(),
153            space: "__missing__".to_string(),
154            version: DEFAULT_VERSION.to_string(),
155            feature_map: FeatureMap::default(),
156            alert_config: PsiAlertConfig::default(),
157            drift_type: DriftType::Psi,
158            categorical_features: None,
159        }
160    }
161}
162// TODO dry this out
163
164impl DispatchDriftConfig for PsiDriftConfig {
165    fn get_drift_args(&self) -> DriftArgs {
166        DriftArgs {
167            name: self.name.clone(),
168            space: self.space.clone(),
169            version: self.version.clone(),
170            dispatch_config: self.alert_config.dispatch_config.clone(),
171        }
172    }
173}
174
175#[pyclass]
176#[derive(Debug, Clone, PartialEq)]
177pub struct Bin {
178    #[pyo3(get)]
179    pub id: usize,
180
181    #[pyo3(get)]
182    pub lower_limit: Option<f64>,
183
184    #[pyo3(get)]
185    pub upper_limit: Option<f64>,
186
187    #[pyo3(get)]
188    pub proportion: f64,
189}
190
191impl Serialize for Bin {
192    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193    where
194        S: Serializer,
195    {
196        let mut state = serializer.serialize_struct("Bin", 4)?;
197        state.serialize_field("id", &self.id)?;
198
199        state.serialize_field(
200            "lower_limit",
201            &self.lower_limit.map(|v| {
202                if v.is_infinite() {
203                    serde_json::Value::String(if v.is_sign_positive() {
204                        "inf".to_string()
205                    } else {
206                        "-inf".to_string()
207                    })
208                } else {
209                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
210                }
211            }),
212        )?;
213        state.serialize_field(
214            "upper_limit",
215            &self.upper_limit.map(|v| {
216                if v.is_infinite() {
217                    serde_json::Value::String(if v.is_sign_positive() {
218                        "inf".to_string()
219                    } else {
220                        "-inf".to_string()
221                    })
222                } else {
223                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
224                }
225            }),
226        )?;
227        state.serialize_field("proportion", &self.proportion)?;
228        state.end()
229    }
230}
231
232impl<'de> Deserialize<'de> for Bin {
233    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
234    where
235        D: Deserializer<'de>,
236    {
237        #[derive(Deserialize)]
238        #[serde(untagged)]
239        enum NumberOrString {
240            Number(f64),
241            String(String),
242        }
243
244        #[derive(Deserialize)]
245        #[serde(field_identifier, rename_all = "snake_case")]
246        enum Field {
247            Id,
248            LowerLimit,
249            UpperLimit,
250            Proportion,
251        }
252
253        struct BinVisitor;
254
255        impl<'de> Visitor<'de> for BinVisitor {
256            type Value = Bin;
257
258            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
259                formatter.write_str("struct Bin")
260            }
261
262            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
263            where
264                V: MapAccess<'de>,
265            {
266                let mut id = None;
267                let mut lower_limit = None;
268                let mut upper_limit = None;
269                let mut proportion = None;
270
271                while let Some(key) = map.next_key()? {
272                    match key {
273                        Field::Id => {
274                            id = Some(map.next_value()?);
275                        }
276                        Field::LowerLimit => {
277                            let val: Option<NumberOrString> = map.next_value()?;
278                            lower_limit = Some(val.map(|v| match v {
279                                NumberOrString::String(s) => match s.as_str() {
280                                    "inf" => f64::INFINITY,
281                                    "-inf" => f64::NEG_INFINITY,
282                                    _ => s.parse().unwrap(),
283                                },
284                                NumberOrString::Number(n) => n,
285                            }));
286                        }
287                        Field::UpperLimit => {
288                            let val: Option<NumberOrString> = map.next_value()?;
289                            upper_limit = Some(val.map(|v| match v {
290                                NumberOrString::String(s) => match s.as_str() {
291                                    "inf" => f64::INFINITY,
292                                    "-inf" => f64::NEG_INFINITY,
293                                    _ => s.parse().unwrap(),
294                                },
295                                NumberOrString::Number(n) => n,
296                            }));
297                        }
298                        Field::Proportion => {
299                            proportion = Some(map.next_value()?);
300                        }
301                    }
302                }
303
304                Ok(Bin {
305                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
306                    lower_limit: lower_limit
307                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
308                    upper_limit: upper_limit
309                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
310                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
311                })
312            }
313        }
314
315        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
316        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
317    }
318}
319
320#[pyclass]
321#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
322pub struct PsiFeatureDriftProfile {
323    #[pyo3(get)]
324    pub id: String,
325
326    #[pyo3(get)]
327    pub bins: Vec<Bin>,
328
329    #[pyo3(get)]
330    pub timestamp: chrono::DateTime<Utc>,
331
332    #[pyo3(get)]
333    pub bin_type: BinType,
334}
335
336#[pyclass]
337#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
338pub struct PsiDriftProfile {
339    #[pyo3(get)]
340    pub features: HashMap<String, PsiFeatureDriftProfile>,
341
342    #[pyo3(get)]
343    pub config: PsiDriftConfig,
344
345    #[pyo3(get)]
346    pub scouter_version: String,
347}
348
349impl PsiDriftProfile {
350    pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
351        Self {
352            features,
353            config,
354            scouter_version: scouter_version(),
355        }
356    }
357}
358
359#[pymethods]
360impl PsiDriftProfile {
361    pub fn __str__(&self) -> String {
362        // serialize the struct to a string
363        ProfileFuncs::__str__(self)
364    }
365
366    pub fn model_dump_json(&self) -> String {
367        // serialize the struct to a string
368        ProfileFuncs::__json__(self)
369    }
370    // TODO dry this out
371    #[allow(clippy::useless_conversion)]
372    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
373        let json_str = serde_json::to_string(&self)?;
374
375        let json_value: Value = serde_json::from_str(&json_str)?;
376
377        // Create a new Python dictionary
378        let dict = PyDict::new(py);
379
380        // Convert JSON to Python dict
381        json_to_pyobject(py, &json_value, &dict)?;
382
383        // Return the Python dictionary
384        Ok(dict.into())
385    }
386
387    #[staticmethod]
388    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
389        let file = std::fs::read_to_string(&path)?;
390
391        Ok(serde_json::from_str(&file)?)
392    }
393
394    #[staticmethod]
395    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
396        let json_value = pyobject_to_json(data).unwrap();
397
398        let string = serde_json::to_string(&json_value).unwrap();
399        serde_json::from_str(&string).expect("Failed to load drift profile")
400    }
401
402    #[staticmethod]
403    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
404        // deserialize the string to a struct
405        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
406    }
407
408    // Convert python dict into a drift profile
409    #[pyo3(signature = (path=None))]
410    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
411        Ok(ProfileFuncs::save_to_json(
412            self,
413            path,
414            FileName::PsiDriftProfile.to_str(),
415        )?)
416    }
417
418    #[allow(clippy::too_many_arguments)]
419    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
420    pub fn update_config_args(
421        &mut self,
422        space: Option<String>,
423        name: Option<String>,
424        version: Option<String>,
425        alert_config: Option<PsiAlertConfig>,
426    ) -> Result<(), TypeError> {
427        self.config
428            .update_config_args(space, name, version, alert_config)
429    }
430
431    /// Create a profile request from the profile
432    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
433        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
434            None
435        } else {
436            Some(self.config.version.clone())
437        };
438
439        Ok(ProfileRequest {
440            space: self.config.space.clone(),
441            profile: self.model_dump_json(),
442            drift_type: self.config.drift_type.clone(),
443            version_request: VersionRequest {
444                version,
445                version_type: VersionType::Minor,
446                pre_tag: None,
447                build_tag: None,
448            },
449        })
450    }
451}
452
453#[pyclass]
454#[derive(Debug, Serialize, Deserialize, Clone)]
455pub struct PsiDriftMap {
456    #[pyo3(get)]
457    pub features: HashMap<String, f64>,
458
459    #[pyo3(get)]
460    pub name: String,
461
462    #[pyo3(get)]
463    pub space: String,
464
465    #[pyo3(get)]
466    pub version: String,
467}
468
469impl PsiDriftMap {
470    pub fn new(space: String, name: String, version: String) -> Self {
471        Self {
472            features: HashMap::new(),
473            name,
474            space,
475            version,
476        }
477    }
478}
479
480#[pymethods]
481#[allow(clippy::new_without_default)]
482impl PsiDriftMap {
483    pub fn __str__(&self) -> String {
484        // serialize the struct to a string
485        ProfileFuncs::__str__(self)
486    }
487
488    pub fn model_dump_json(&self) -> String {
489        // serialize the struct to a string
490        ProfileFuncs::__json__(self)
491    }
492
493    #[staticmethod]
494    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
495        // deserialize the string to a struct
496        Ok(serde_json::from_str(&json_string)?)
497    }
498
499    #[pyo3(signature = (path=None))]
500    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
501        Ok(ProfileFuncs::save_to_json(
502            self,
503            path,
504            FileName::PsiDriftMap.to_str(),
505        )?)
506    }
507}
508
509// TODO dry this out
510impl ProfileBaseArgs for PsiDriftProfile {
511    /// Get the base arguments for the profile (convenience method on the server)
512    fn get_base_args(&self) -> ProfileArgs {
513        ProfileArgs {
514            name: self.config.name.clone(),
515            space: self.config.space.clone(),
516            version: Some(self.config.version.clone()),
517            schedule: self.config.alert_config.schedule.clone(),
518            scouter_version: self.scouter_version.clone(),
519            drift_type: self.config.drift_type.clone(),
520        }
521    }
522
523    /// Convert the struct to a serde_json::Value
524    fn to_value(&self) -> Value {
525        serde_json::to_value(self).unwrap()
526    }
527}
528
529#[derive(Debug, Clone, Serialize, Deserialize)]
530pub struct DistributionData {
531    pub sample_size: u64,
532    pub bins: BTreeMap<usize, f64>,
533}
534
535#[derive(Debug, Clone, Serialize, Deserialize)]
536pub struct FeatureDistributions {
537    pub distributions: BTreeMap<String, DistributionData>,
538}
539
540impl FeatureDistributions {
541    pub fn is_empty(&self) -> bool {
542        self.distributions.is_empty()
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_config() {
        let mut drift_config = PsiDriftConfig::new(
            MISSING,
            MISSING,
            DEFAULT_VERSION,
            PsiAlertConfig::default(),
            None,
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.0.0");
        assert_eq!(drift_config.alert_config, PsiAlertConfig::default());

        // update
        drift_config
            .update_config_args(None, Some("test".to_string()), None, None)
            .unwrap();

        assert_eq!(drift_config.name, "test");
    }

    /// Only the arguments that are `Some` should be applied; the rest stay unchanged.
    #[test]
    fn test_update_config_args_is_partial() {
        let mut drift_config = PsiDriftConfig::default();

        drift_config
            .update_config_args(
                Some("space".to_string()),
                None,
                Some("1.0.0".to_string()),
                None,
            )
            .unwrap();

        assert_eq!(drift_config.space, "space");
        assert_eq!(drift_config.version, "1.0.0");
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.alert_config, PsiAlertConfig::default());
    }

    /// Infinite limits are encoded as strings and decode back to the same bin.
    #[test]
    fn test_bin_infinite_limits_round_trip() {
        let bin = Bin {
            id: 1,
            lower_limit: Some(f64::NEG_INFINITY),
            upper_limit: Some(f64::INFINITY),
            proportion: 0.5,
        };

        let json = serde_json::to_value(&bin).unwrap();
        assert_eq!(json["lower_limit"], "-inf");
        assert_eq!(json["upper_limit"], "inf");

        let decoded: Bin = serde_json::from_value(json).unwrap();
        assert_eq!(decoded, bin);
    }

    /// Finite and absent limits survive a JSON string round trip unchanged.
    #[test]
    fn test_bin_finite_and_missing_limits_round_trip() {
        let bin = Bin {
            id: 0,
            lower_limit: None,
            upper_limit: Some(1.5),
            proportion: 0.25,
        };

        let json = serde_json::to_string(&bin).unwrap();
        let decoded: Bin = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, bin);
    }
}