scouter_types/psi/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
8    ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use chrono::Utc;
11use core::fmt::Debug;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use serde::de::{self, MapAccess, Visitor};
15use serde::ser::SerializeStruct;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17use serde_json::Value;
18use std::collections::{BTreeMap, HashMap};
19use std::path::PathBuf;
20use tracing::debug;
21
22#[pyclass(eq)]
23#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
24pub enum BinType {
25    Binary,
26    Numeric,
27    Category,
28}
29
30#[pyclass]
31#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32pub struct PsiDriftConfig {
33    #[pyo3(get, set)]
34    pub space: String,
35
36    #[pyo3(get, set)]
37    pub name: String,
38
39    #[pyo3(get, set)]
40    pub version: String,
41
42    #[pyo3(get, set)]
43    pub alert_config: PsiAlertConfig,
44
45    #[pyo3(get)]
46    #[serde(default)]
47    pub feature_map: FeatureMap,
48
49    #[pyo3(get, set)]
50    #[serde(default = "default_drift_type")]
51    pub drift_type: DriftType,
52}
53
54fn default_drift_type() -> DriftType {
55    DriftType::Psi
56}
57
58impl PsiDriftConfig {
59    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
60        self.feature_map = feature_map;
61    }
62}
63
64#[pymethods]
65#[allow(clippy::too_many_arguments)]
66impl PsiDriftConfig {
67    // TODO dry this out
68    #[new]
69    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None))]
70    pub fn new(
71        space: &str,
72        name: &str,
73        version: &str,
74        alert_config: PsiAlertConfig,
75        config_path: Option<PathBuf>,
76    ) -> Result<Self, ProfileError> {
77        if let Some(config_path) = config_path {
78            let config = PsiDriftConfig::load_from_json_file(config_path);
79            return config;
80        }
81
82        if name == MISSING || space == MISSING {
83            debug!("Name and space were not provided. Defaulting to __missing__");
84        }
85
86        Ok(Self {
87            name: name.to_string(),
88            space: space.to_string(),
89            version: version.to_string(),
90            alert_config,
91            feature_map: FeatureMap::default(),
92            drift_type: DriftType::Psi,
93        })
94    }
95
96    #[staticmethod]
97    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
98        // deserialize the string to a struct
99
100        let file = std::fs::read_to_string(&path)?;
101
102        Ok(serde_json::from_str(&file)?)
103    }
104
105    pub fn __str__(&self) -> String {
106        // serialize the struct to a string
107        ProfileFuncs::__str__(self)
108    }
109
110    pub fn model_dump_json(&self) -> String {
111        // serialize the struct to a string
112        ProfileFuncs::__json__(self)
113    }
114
115    #[allow(clippy::too_many_arguments)]
116    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
117    pub fn update_config_args(
118        &mut self,
119        space: Option<String>,
120        name: Option<String>,
121        version: Option<String>,
122        alert_config: Option<PsiAlertConfig>,
123    ) -> Result<(), TypeError> {
124        if name.is_some() {
125            self.name = name.ok_or(TypeError::MissingNameError)?;
126        }
127
128        if space.is_some() {
129            self.space = space.ok_or(TypeError::MissingSpaceError)?;
130        }
131
132        if version.is_some() {
133            self.version = version.ok_or(TypeError::MissingVersionError)?;
134        }
135
136        if alert_config.is_some() {
137            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
138        }
139
140        Ok(())
141    }
142}
143
144impl Default for PsiDriftConfig {
145    fn default() -> Self {
146        PsiDriftConfig {
147            name: "__missing__".to_string(),
148            space: "__missing__".to_string(),
149            version: DEFAULT_VERSION.to_string(),
150            feature_map: FeatureMap::default(),
151            alert_config: PsiAlertConfig::default(),
152            drift_type: DriftType::Psi,
153        }
154    }
155}
156// TODO dry this out
157
158impl DispatchDriftConfig for PsiDriftConfig {
159    fn get_drift_args(&self) -> DriftArgs {
160        DriftArgs {
161            name: self.name.clone(),
162            space: self.space.clone(),
163            version: self.version.clone(),
164            dispatch_config: self.alert_config.dispatch_config.clone(),
165        }
166    }
167}
168
169#[pyclass]
170#[derive(Debug, Clone, PartialEq)]
171pub struct Bin {
172    #[pyo3(get)]
173    pub id: usize,
174
175    #[pyo3(get)]
176    pub lower_limit: Option<f64>,
177
178    #[pyo3(get)]
179    pub upper_limit: Option<f64>,
180
181    #[pyo3(get)]
182    pub proportion: f64,
183}
184
185impl Serialize for Bin {
186    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
187    where
188        S: Serializer,
189    {
190        let mut state = serializer.serialize_struct("Bin", 4)?;
191        state.serialize_field("id", &self.id)?;
192
193        state.serialize_field(
194            "lower_limit",
195            &self.lower_limit.map(|v| {
196                if v.is_infinite() {
197                    serde_json::Value::String(if v.is_sign_positive() {
198                        "inf".to_string()
199                    } else {
200                        "-inf".to_string()
201                    })
202                } else {
203                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
204                }
205            }),
206        )?;
207        state.serialize_field(
208            "upper_limit",
209            &self.upper_limit.map(|v| {
210                if v.is_infinite() {
211                    serde_json::Value::String(if v.is_sign_positive() {
212                        "inf".to_string()
213                    } else {
214                        "-inf".to_string()
215                    })
216                } else {
217                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
218                }
219            }),
220        )?;
221        state.serialize_field("proportion", &self.proportion)?;
222        state.end()
223    }
224}
225
226impl<'de> Deserialize<'de> for Bin {
227    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
228    where
229        D: Deserializer<'de>,
230    {
231        #[derive(Deserialize)]
232        #[serde(untagged)]
233        enum NumberOrString {
234            Number(f64),
235            String(String),
236        }
237
238        #[derive(Deserialize)]
239        #[serde(field_identifier, rename_all = "snake_case")]
240        enum Field {
241            Id,
242            LowerLimit,
243            UpperLimit,
244            Proportion,
245        }
246
247        struct BinVisitor;
248
249        impl<'de> Visitor<'de> for BinVisitor {
250            type Value = Bin;
251
252            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
253                formatter.write_str("struct Bin")
254            }
255
256            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
257            where
258                V: MapAccess<'de>,
259            {
260                let mut id = None;
261                let mut lower_limit = None;
262                let mut upper_limit = None;
263                let mut proportion = None;
264
265                while let Some(key) = map.next_key()? {
266                    match key {
267                        Field::Id => {
268                            id = Some(map.next_value()?);
269                        }
270                        Field::LowerLimit => {
271                            let val: Option<NumberOrString> = map.next_value()?;
272                            lower_limit = Some(val.map(|v| match v {
273                                NumberOrString::String(s) => match s.as_str() {
274                                    "inf" => f64::INFINITY,
275                                    "-inf" => f64::NEG_INFINITY,
276                                    _ => s.parse().unwrap(),
277                                },
278                                NumberOrString::Number(n) => n,
279                            }));
280                        }
281                        Field::UpperLimit => {
282                            let val: Option<NumberOrString> = map.next_value()?;
283                            upper_limit = Some(val.map(|v| match v {
284                                NumberOrString::String(s) => match s.as_str() {
285                                    "inf" => f64::INFINITY,
286                                    "-inf" => f64::NEG_INFINITY,
287                                    _ => s.parse().unwrap(),
288                                },
289                                NumberOrString::Number(n) => n,
290                            }));
291                        }
292                        Field::Proportion => {
293                            proportion = Some(map.next_value()?);
294                        }
295                    }
296                }
297
298                Ok(Bin {
299                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
300                    lower_limit: lower_limit
301                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
302                    upper_limit: upper_limit
303                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
304                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
305                })
306            }
307        }
308
309        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
310        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
311    }
312}
313
314#[pyclass]
315#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
316pub struct PsiFeatureDriftProfile {
317    #[pyo3(get)]
318    pub id: String,
319
320    #[pyo3(get)]
321    pub bins: Vec<Bin>,
322
323    #[pyo3(get)]
324    pub timestamp: chrono::DateTime<Utc>,
325
326    #[pyo3(get)]
327    pub bin_type: BinType,
328}
329
330#[pyclass]
331#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
332pub struct PsiDriftProfile {
333    #[pyo3(get)]
334    pub features: HashMap<String, PsiFeatureDriftProfile>,
335
336    #[pyo3(get)]
337    pub config: PsiDriftConfig,
338
339    #[pyo3(get)]
340    pub scouter_version: String,
341}
342
343impl PsiDriftProfile {
344    pub fn new(
345        features: HashMap<String, PsiFeatureDriftProfile>,
346        config: PsiDriftConfig,
347        scouter_version: Option<String>,
348    ) -> Self {
349        let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
350        Self {
351            features,
352            config,
353            scouter_version,
354        }
355    }
356}
357
358#[pymethods]
359impl PsiDriftProfile {
360    pub fn __str__(&self) -> String {
361        // serialize the struct to a string
362        ProfileFuncs::__str__(self)
363    }
364
365    pub fn model_dump_json(&self) -> String {
366        // serialize the struct to a string
367        ProfileFuncs::__json__(self)
368    }
369    // TODO dry this out
370    #[allow(clippy::useless_conversion)]
371    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
372        let json_str = serde_json::to_string(&self)?;
373
374        let json_value: Value = serde_json::from_str(&json_str)?;
375
376        // Create a new Python dictionary
377        let dict = PyDict::new(py);
378
379        // Convert JSON to Python dict
380        json_to_pyobject(py, &json_value, &dict)?;
381
382        // Return the Python dictionary
383        Ok(dict.into())
384    }
385
386    #[staticmethod]
387    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
388        let file = std::fs::read_to_string(&path)?;
389
390        Ok(serde_json::from_str(&file)?)
391    }
392
393    #[staticmethod]
394    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
395        let json_value = pyobject_to_json(data).unwrap();
396
397        let string = serde_json::to_string(&json_value).unwrap();
398        serde_json::from_str(&string).expect("Failed to load drift profile")
399    }
400
401    #[staticmethod]
402    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
403        // deserialize the string to a struct
404        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
405    }
406
407    // Convert python dict into a drift profile
408    #[pyo3(signature = (path=None))]
409    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
410        Ok(ProfileFuncs::save_to_json(
411            self,
412            path,
413            FileName::PsiDriftProfile.to_str(),
414        )?)
415    }
416
417    #[allow(clippy::too_many_arguments)]
418    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
419    pub fn update_config_args(
420        &mut self,
421        space: Option<String>,
422        name: Option<String>,
423        version: Option<String>,
424        alert_config: Option<PsiAlertConfig>,
425    ) -> Result<(), TypeError> {
426        self.config
427            .update_config_args(space, name, version, alert_config)
428    }
429
430    /// Create a profile request from the profile
431    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
432        Ok(ProfileRequest {
433            space: self.config.space.clone(),
434            profile: self.model_dump_json(),
435            drift_type: self.config.drift_type.clone(),
436        })
437    }
438}
439
440#[pyclass]
441#[derive(Debug, Serialize, Deserialize, Clone)]
442pub struct PsiDriftMap {
443    #[pyo3(get)]
444    pub features: HashMap<String, f64>,
445
446    #[pyo3(get)]
447    pub name: String,
448
449    #[pyo3(get)]
450    pub space: String,
451
452    #[pyo3(get)]
453    pub version: String,
454}
455
456impl PsiDriftMap {
457    pub fn new(space: String, name: String, version: String) -> Self {
458        Self {
459            features: HashMap::new(),
460            name,
461            space,
462            version,
463        }
464    }
465}
466
467#[pymethods]
468#[allow(clippy::new_without_default)]
469impl PsiDriftMap {
470    pub fn __str__(&self) -> String {
471        // serialize the struct to a string
472        ProfileFuncs::__str__(self)
473    }
474
475    pub fn model_dump_json(&self) -> String {
476        // serialize the struct to a string
477        ProfileFuncs::__json__(self)
478    }
479
480    #[staticmethod]
481    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
482        // deserialize the string to a struct
483        Ok(serde_json::from_str(&json_string)?)
484    }
485
486    #[pyo3(signature = (path=None))]
487    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
488        Ok(ProfileFuncs::save_to_json(
489            self,
490            path,
491            FileName::PsiDriftMap.to_str(),
492        )?)
493    }
494}
495
496// TODO dry this out
497impl ProfileBaseArgs for PsiDriftProfile {
498    /// Get the base arguments for the profile (convenience method on the server)
499    fn get_base_args(&self) -> ProfileArgs {
500        ProfileArgs {
501            name: self.config.name.clone(),
502            space: self.config.space.clone(),
503            version: self.config.version.clone(),
504            schedule: self.config.alert_config.schedule.clone(),
505            scouter_version: self.scouter_version.clone(),
506            drift_type: self.config.drift_type.clone(),
507        }
508    }
509
510    /// Convert the struct to a serde_json::Value
511    fn to_value(&self) -> Value {
512        serde_json::to_value(self).unwrap()
513    }
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct FeatureBinProportion {
518    pub feature: String,
519    pub bins: BTreeMap<usize, f64>,
520}
521
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct FeatureBinProportions {
524    pub features: BTreeMap<String, BTreeMap<usize, f64>>,
525}
526
527impl FromIterator<FeatureBinProportion> for FeatureBinProportions {
528    fn from_iter<T: IntoIterator<Item = FeatureBinProportion>>(iter: T) -> Self {
529        let mut feature_map: BTreeMap<String, BTreeMap<usize, f64>> = BTreeMap::new();
530        for feature in iter {
531            feature_map.insert(feature.feature, feature.bins);
532        }
533        FeatureBinProportions {
534            features: feature_map,
535        }
536    }
537}
538
539impl FeatureBinProportions {
540    pub fn from_features(features: Vec<FeatureBinProportion>) -> Self {
541        let mut feature_map: BTreeMap<String, BTreeMap<usize, f64>> = BTreeMap::new();
542        for feature in features {
543            feature_map.insert(feature.feature, feature.bins);
544        }
545        FeatureBinProportions {
546            features: feature_map,
547        }
548    }
549
550    pub fn get(&self, feature: &str, bin: &usize) -> Option<&f64> {
551        self.features.get(feature).and_then(|f| f.get(bin))
552    }
553
554    pub fn is_empty(&self) -> bool {
555        self.features.is_empty()
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_config() {
        let alert_config = PsiAlertConfig::default();
        let mut drift_config =
            PsiDriftConfig::new(MISSING, MISSING, DEFAULT_VERSION, alert_config, None).unwrap();

        // Without identity arguments the sentinel values and defaults are used.
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(drift_config.alert_config, PsiAlertConfig::default());

        // Updating only the name leaves everything else alone.
        let new_name = Some("test".to_string());
        drift_config
            .update_config_args(None, new_name, None, None)
            .unwrap();

        assert_eq!(drift_config.name, "test");
    }
}