scouter_types/psi/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7    DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
8    ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use chrono::Utc;
11use core::fmt::Debug;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use serde::de::{self, MapAccess, Visitor};
15use serde::ser::SerializeStruct;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17use serde_json::Value;
18use std::collections::{BTreeMap, HashMap};
19use std::path::PathBuf;
20use tracing::debug;
21
22#[pyclass(eq)]
23#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
24pub enum BinType {
25    Numeric,
26    Category,
27}
28
29#[pyclass]
30#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
31pub struct PsiDriftConfig {
32    #[pyo3(get, set)]
33    pub space: String,
34
35    #[pyo3(get, set)]
36    pub name: String,
37
38    #[pyo3(get, set)]
39    pub version: String,
40
41    #[pyo3(get, set)]
42    pub alert_config: PsiAlertConfig,
43
44    #[pyo3(get)]
45    #[serde(default)]
46    pub feature_map: FeatureMap,
47
48    #[pyo3(get, set)]
49    #[serde(default = "default_drift_type")]
50    pub drift_type: DriftType,
51
52    #[pyo3(get, set)]
53    pub categorical_features: Option<Vec<String>>,
54}
55
56fn default_drift_type() -> DriftType {
57    DriftType::Psi
58}
59
60impl PsiDriftConfig {
61    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
62        self.feature_map = feature_map;
63    }
64}
65
66#[pymethods]
67#[allow(clippy::too_many_arguments)]
68impl PsiDriftConfig {
69    #[new]
70    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None))]
71    pub fn new(
72        space: &str,
73        name: &str,
74        version: &str,
75        alert_config: PsiAlertConfig,
76        config_path: Option<PathBuf>,
77        categorical_features: Option<Vec<String>>,
78    ) -> Result<Self, ProfileError> {
79        if let Some(config_path) = config_path {
80            let config = PsiDriftConfig::load_from_json_file(config_path);
81            return config;
82        }
83
84        if name == MISSING || space == MISSING {
85            debug!("Name and space were not provided. Defaulting to __missing__");
86        }
87
88        Ok(Self {
89            name: name.to_string(),
90            space: space.to_string(),
91            version: version.to_string(),
92            alert_config,
93            categorical_features,
94            feature_map: FeatureMap::default(),
95            drift_type: DriftType::Psi,
96        })
97    }
98
99    #[staticmethod]
100    pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
101        // deserialize the string to a struct
102
103        let file = std::fs::read_to_string(&path)?;
104
105        Ok(serde_json::from_str(&file)?)
106    }
107
108    pub fn __str__(&self) -> String {
109        // serialize the struct to a string
110        ProfileFuncs::__str__(self)
111    }
112
113    pub fn model_dump_json(&self) -> String {
114        // serialize the struct to a string
115        ProfileFuncs::__json__(self)
116    }
117
118    #[allow(clippy::too_many_arguments)]
119    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
120    pub fn update_config_args(
121        &mut self,
122        space: Option<String>,
123        name: Option<String>,
124        version: Option<String>,
125        alert_config: Option<PsiAlertConfig>,
126    ) -> Result<(), TypeError> {
127        if name.is_some() {
128            self.name = name.ok_or(TypeError::MissingNameError)?;
129        }
130
131        if space.is_some() {
132            self.space = space.ok_or(TypeError::MissingSpaceError)?;
133        }
134
135        if version.is_some() {
136            self.version = version.ok_or(TypeError::MissingVersionError)?;
137        }
138
139        if alert_config.is_some() {
140            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
141        }
142
143        Ok(())
144    }
145}
146
147impl Default for PsiDriftConfig {
148    fn default() -> Self {
149        PsiDriftConfig {
150            name: "__missing__".to_string(),
151            space: "__missing__".to_string(),
152            version: DEFAULT_VERSION.to_string(),
153            feature_map: FeatureMap::default(),
154            alert_config: PsiAlertConfig::default(),
155            drift_type: DriftType::Psi,
156            categorical_features: None,
157        }
158    }
159}
160// TODO dry this out
161
162impl DispatchDriftConfig for PsiDriftConfig {
163    fn get_drift_args(&self) -> DriftArgs {
164        DriftArgs {
165            name: self.name.clone(),
166            space: self.space.clone(),
167            version: self.version.clone(),
168            dispatch_config: self.alert_config.dispatch_config.clone(),
169        }
170    }
171}
172
173#[pyclass]
174#[derive(Debug, Clone, PartialEq)]
175pub struct Bin {
176    #[pyo3(get)]
177    pub id: usize,
178
179    #[pyo3(get)]
180    pub lower_limit: Option<f64>,
181
182    #[pyo3(get)]
183    pub upper_limit: Option<f64>,
184
185    #[pyo3(get)]
186    pub proportion: f64,
187}
188
189impl Serialize for Bin {
190    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191    where
192        S: Serializer,
193    {
194        let mut state = serializer.serialize_struct("Bin", 4)?;
195        state.serialize_field("id", &self.id)?;
196
197        state.serialize_field(
198            "lower_limit",
199            &self.lower_limit.map(|v| {
200                if v.is_infinite() {
201                    serde_json::Value::String(if v.is_sign_positive() {
202                        "inf".to_string()
203                    } else {
204                        "-inf".to_string()
205                    })
206                } else {
207                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
208                }
209            }),
210        )?;
211        state.serialize_field(
212            "upper_limit",
213            &self.upper_limit.map(|v| {
214                if v.is_infinite() {
215                    serde_json::Value::String(if v.is_sign_positive() {
216                        "inf".to_string()
217                    } else {
218                        "-inf".to_string()
219                    })
220                } else {
221                    serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
222                }
223            }),
224        )?;
225        state.serialize_field("proportion", &self.proportion)?;
226        state.end()
227    }
228}
229
230impl<'de> Deserialize<'de> for Bin {
231    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
232    where
233        D: Deserializer<'de>,
234    {
235        #[derive(Deserialize)]
236        #[serde(untagged)]
237        enum NumberOrString {
238            Number(f64),
239            String(String),
240        }
241
242        #[derive(Deserialize)]
243        #[serde(field_identifier, rename_all = "snake_case")]
244        enum Field {
245            Id,
246            LowerLimit,
247            UpperLimit,
248            Proportion,
249        }
250
251        struct BinVisitor;
252
253        impl<'de> Visitor<'de> for BinVisitor {
254            type Value = Bin;
255
256            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
257                formatter.write_str("struct Bin")
258            }
259
260            fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
261            where
262                V: MapAccess<'de>,
263            {
264                let mut id = None;
265                let mut lower_limit = None;
266                let mut upper_limit = None;
267                let mut proportion = None;
268
269                while let Some(key) = map.next_key()? {
270                    match key {
271                        Field::Id => {
272                            id = Some(map.next_value()?);
273                        }
274                        Field::LowerLimit => {
275                            let val: Option<NumberOrString> = map.next_value()?;
276                            lower_limit = Some(val.map(|v| match v {
277                                NumberOrString::String(s) => match s.as_str() {
278                                    "inf" => f64::INFINITY,
279                                    "-inf" => f64::NEG_INFINITY,
280                                    _ => s.parse().unwrap(),
281                                },
282                                NumberOrString::Number(n) => n,
283                            }));
284                        }
285                        Field::UpperLimit => {
286                            let val: Option<NumberOrString> = map.next_value()?;
287                            upper_limit = Some(val.map(|v| match v {
288                                NumberOrString::String(s) => match s.as_str() {
289                                    "inf" => f64::INFINITY,
290                                    "-inf" => f64::NEG_INFINITY,
291                                    _ => s.parse().unwrap(),
292                                },
293                                NumberOrString::Number(n) => n,
294                            }));
295                        }
296                        Field::Proportion => {
297                            proportion = Some(map.next_value()?);
298                        }
299                    }
300                }
301
302                Ok(Bin {
303                    id: id.ok_or_else(|| de::Error::missing_field("id"))?,
304                    lower_limit: lower_limit
305                        .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
306                    upper_limit: upper_limit
307                        .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
308                    proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
309                })
310            }
311        }
312
313        const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
314        deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
315    }
316}
317
318#[pyclass]
319#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
320pub struct PsiFeatureDriftProfile {
321    #[pyo3(get)]
322    pub id: String,
323
324    #[pyo3(get)]
325    pub bins: Vec<Bin>,
326
327    #[pyo3(get)]
328    pub timestamp: chrono::DateTime<Utc>,
329
330    #[pyo3(get)]
331    pub bin_type: BinType,
332}
333
334#[pyclass]
335#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
336pub struct PsiDriftProfile {
337    #[pyo3(get)]
338    pub features: HashMap<String, PsiFeatureDriftProfile>,
339
340    #[pyo3(get)]
341    pub config: PsiDriftConfig,
342
343    #[pyo3(get)]
344    pub scouter_version: String,
345}
346
347impl PsiDriftProfile {
348    pub fn new(
349        features: HashMap<String, PsiFeatureDriftProfile>,
350        config: PsiDriftConfig,
351        scouter_version: Option<String>,
352    ) -> Self {
353        let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
354        Self {
355            features,
356            config,
357            scouter_version,
358        }
359    }
360}
361
362#[pymethods]
363impl PsiDriftProfile {
364    pub fn __str__(&self) -> String {
365        // serialize the struct to a string
366        ProfileFuncs::__str__(self)
367    }
368
369    pub fn model_dump_json(&self) -> String {
370        // serialize the struct to a string
371        ProfileFuncs::__json__(self)
372    }
373    // TODO dry this out
374    #[allow(clippy::useless_conversion)]
375    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
376        let json_str = serde_json::to_string(&self)?;
377
378        let json_value: Value = serde_json::from_str(&json_str)?;
379
380        // Create a new Python dictionary
381        let dict = PyDict::new(py);
382
383        // Convert JSON to Python dict
384        json_to_pyobject(py, &json_value, &dict)?;
385
386        // Return the Python dictionary
387        Ok(dict.into())
388    }
389
390    #[staticmethod]
391    pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
392        let file = std::fs::read_to_string(&path)?;
393
394        Ok(serde_json::from_str(&file)?)
395    }
396
397    #[staticmethod]
398    pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
399        let json_value = pyobject_to_json(data).unwrap();
400
401        let string = serde_json::to_string(&json_value).unwrap();
402        serde_json::from_str(&string).expect("Failed to load drift profile")
403    }
404
405    #[staticmethod]
406    pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
407        // deserialize the string to a struct
408        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
409    }
410
411    // Convert python dict into a drift profile
412    #[pyo3(signature = (path=None))]
413    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
414        Ok(ProfileFuncs::save_to_json(
415            self,
416            path,
417            FileName::PsiDriftProfile.to_str(),
418        )?)
419    }
420
421    #[allow(clippy::too_many_arguments)]
422    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
423    pub fn update_config_args(
424        &mut self,
425        space: Option<String>,
426        name: Option<String>,
427        version: Option<String>,
428        alert_config: Option<PsiAlertConfig>,
429    ) -> Result<(), TypeError> {
430        self.config
431            .update_config_args(space, name, version, alert_config)
432    }
433
434    /// Create a profile request from the profile
435    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
436        Ok(ProfileRequest {
437            space: self.config.space.clone(),
438            profile: self.model_dump_json(),
439            drift_type: self.config.drift_type.clone(),
440        })
441    }
442}
443
444#[pyclass]
445#[derive(Debug, Serialize, Deserialize, Clone)]
446pub struct PsiDriftMap {
447    #[pyo3(get)]
448    pub features: HashMap<String, f64>,
449
450    #[pyo3(get)]
451    pub name: String,
452
453    #[pyo3(get)]
454    pub space: String,
455
456    #[pyo3(get)]
457    pub version: String,
458}
459
460impl PsiDriftMap {
461    pub fn new(space: String, name: String, version: String) -> Self {
462        Self {
463            features: HashMap::new(),
464            name,
465            space,
466            version,
467        }
468    }
469}
470
471#[pymethods]
472#[allow(clippy::new_without_default)]
473impl PsiDriftMap {
474    pub fn __str__(&self) -> String {
475        // serialize the struct to a string
476        ProfileFuncs::__str__(self)
477    }
478
479    pub fn model_dump_json(&self) -> String {
480        // serialize the struct to a string
481        ProfileFuncs::__json__(self)
482    }
483
484    #[staticmethod]
485    pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
486        // deserialize the string to a struct
487        Ok(serde_json::from_str(&json_string)?)
488    }
489
490    #[pyo3(signature = (path=None))]
491    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
492        Ok(ProfileFuncs::save_to_json(
493            self,
494            path,
495            FileName::PsiDriftMap.to_str(),
496        )?)
497    }
498}
499
500// TODO dry this out
501impl ProfileBaseArgs for PsiDriftProfile {
502    /// Get the base arguments for the profile (convenience method on the server)
503    fn get_base_args(&self) -> ProfileArgs {
504        ProfileArgs {
505            name: self.config.name.clone(),
506            space: self.config.space.clone(),
507            version: self.config.version.clone(),
508            schedule: self.config.alert_config.schedule.clone(),
509            scouter_version: self.scouter_version.clone(),
510            drift_type: self.config.drift_type.clone(),
511        }
512    }
513
514    /// Convert the struct to a serde_json::Value
515    fn to_value(&self) -> Value {
516        serde_json::to_value(self).unwrap()
517    }
518}
519
520#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct DistributionData {
522    pub sample_size: u64,
523    pub bins: BTreeMap<usize, f64>,
524}
525
526#[derive(Debug, Clone, Serialize, Deserialize)]
527pub struct FeatureDistributions {
528    pub distributions: BTreeMap<String, DistributionData>,
529}
530
531impl FeatureDistributions {
532    pub fn is_empty(&self) -> bool {
533        self.distributions.is_empty()
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built from the default arguments carries the placeholder
    /// identifiers and can be renamed through `update_config_args`.
    #[test]
    fn test_drift_config() {
        let built = PsiDriftConfig::new(
            MISSING,
            MISSING,
            DEFAULT_VERSION,
            PsiAlertConfig::default(),
            None,
            None,
        );
        let mut config = built.unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(config.alert_config, PsiAlertConfig::default());

        // Only the name is supplied, so only the name changes.
        let new_name = Some("test".to_string());
        config
            .update_config_args(None, new_name, None, None)
            .unwrap();

        assert_eq!(config.name, "test");
    }
}