// scouter_types/custom/profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::{ConfigExt, ProfileRequest};
6use crate::{
7    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use potato_head::create_uuid7;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use scouter_semver::VersionType;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::path::PathBuf;
19
/// Configuration for custom-metric drift monitoring.
///
/// Exposed to Python as a pyclass; every field has a Python getter and setter.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomMetricDriftConfig {
    /// Sample size; the Python constructor defaults it to 25.
    /// NOTE(review): how this value is consumed is not visible in this file.
    #[pyo3(get, set)]
    pub sample_size: usize,

    /// Space the monitored entity belongs to.
    #[pyo3(get, set)]
    pub space: String,

    /// Name of the monitored entity.
    #[pyo3(get, set)]
    pub name: String,

    /// Version string of the monitored entity.
    #[pyo3(get, set)]
    pub version: String,

    /// Unique identifier; `new` generates it with `create_uuid7`.
    #[pyo3(get, set)]
    pub uid: String,

    /// Alerting configuration (schedule, dispatch target, per-metric alert conditions).
    #[pyo3(get, set)]
    pub alert_config: CustomMetricAlertConfig,

    /// Drift type tag. When the field is missing from serialized input it falls back to
    /// `default_drift_type()` (`DriftType::Custom`).
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
45
46impl ConfigExt for CustomMetricDriftConfig {
47    fn space(&self) -> &str {
48        &self.space
49    }
50
51    fn name(&self) -> &str {
52        &self.name
53    }
54
55    fn version(&self) -> &str {
56        &self.version
57    }
58}
59
/// Serde default for `CustomMetricDriftConfig::drift_type`.
///
/// Used when a serialized config has no `drift_type` field.
fn default_drift_type() -> DriftType {
    DriftType::Custom
}
63
64impl DispatchDriftConfig for CustomMetricDriftConfig {
65    fn get_drift_args(&self) -> DriftArgs {
66        DriftArgs {
67            name: self.name.clone(),
68            space: self.space.clone(),
69            version: self.version.clone(),
70            dispatch_config: self.alert_config.dispatch_config.clone(),
71        }
72    }
73}
74
75#[pymethods]
76#[allow(clippy::too_many_arguments)]
77impl CustomMetricDriftConfig {
78    #[new]
79    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
80    pub fn new(
81        space: &str,
82        name: &str,
83        version: &str,
84        sample_size: usize,
85        alert_config: CustomMetricAlertConfig,
86        config_path: Option<PathBuf>,
87    ) -> Result<Self, ProfileError> {
88        if let Some(config_path) = config_path {
89            let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
90            return Ok(config);
91        }
92
93        Ok(Self {
94            sample_size,
95            space: space.to_string(),
96            name: name.to_string(),
97            version: version.to_string(),
98            uid: create_uuid7(),
99            alert_config,
100            drift_type: DriftType::Custom,
101        })
102    }
103
104    #[staticmethod]
105    pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
106        // deserialize the string to a struct
107
108        let file = std::fs::read_to_string(&path)?;
109
110        Ok(serde_json::from_str(&file)?)
111    }
112
113    pub fn __str__(&self) -> String {
114        // serialize the struct to a string
115        PyHelperFuncs::__str__(self)
116    }
117
118    pub fn model_dump_json(&self) -> String {
119        // serialize the struct to a string
120        PyHelperFuncs::__json__(self)
121    }
122
123    #[allow(clippy::too_many_arguments)]
124    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
125    pub fn update_config_args(
126        &mut self,
127        space: Option<String>,
128        name: Option<String>,
129        version: Option<String>,
130        uid: Option<String>,
131        alert_config: Option<CustomMetricAlertConfig>,
132    ) -> Result<(), TypeError> {
133        if name.is_some() {
134            self.name = name.ok_or(TypeError::MissingNameError)?;
135        }
136
137        if space.is_some() {
138            self.space = space.ok_or(TypeError::MissingSpaceError)?;
139        }
140
141        if version.is_some() {
142            self.version = version.ok_or(TypeError::MissingVersionError)?;
143        }
144
145        if alert_config.is_some() {
146            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
147        }
148
149        if uid.is_some() {
150            self.uid = uid.ok_or(TypeError::MissingUidError)?;
151        }
152
153        Ok(())
154    }
155}
156
/// A custom-metric drift profile: the drift config plus the baseline value of each metric.
///
/// Exposed to Python as a pyclass whose fields have getters only.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomDriftProfile {
    /// Drift configuration. `CustomDriftProfile::new` fills in its alert conditions from the
    /// supplied metrics.
    #[pyo3(get)]
    pub config: CustomMetricDriftConfig,

    /// Baseline value per metric, keyed by metric name.
    #[pyo3(get)]
    pub metrics: HashMap<String, f64>,

    /// Scouter version string, recorded from `scouter_version()` when the profile is
    /// constructed.
    #[pyo3(get)]
    pub scouter_version: String,
}
169
170impl Default for CustomDriftProfile {
171    fn default() -> Self {
172        Self {
173            config: CustomMetricDriftConfig::new(
174                MISSING,
175                MISSING,
176                DEFAULT_VERSION,
177                25,
178                CustomMetricAlertConfig::default(),
179                None,
180            )
181            .unwrap(),
182            metrics: HashMap::new(),
183            scouter_version: scouter_version(),
184        }
185    }
186}
187
188#[pymethods]
189impl CustomDriftProfile {
190    #[new]
191    #[pyo3(signature = (config, metrics))]
192    pub fn new(
193        mut config: CustomMetricDriftConfig,
194        metrics: Vec<CustomMetric>,
195    ) -> Result<Self, ProfileError> {
196        if metrics.is_empty() {
197            return Err(TypeError::NoMetricsError.into());
198        }
199
200        config.alert_config.set_alert_conditions(&metrics);
201
202        let metric_vals = metrics
203            .iter()
204            .map(|m| (m.name.clone(), m.baseline_value))
205            .collect();
206
207        Ok(Self {
208            config,
209            metrics: metric_vals,
210            scouter_version: scouter_version(),
211        })
212    }
213
214    pub fn __str__(&self) -> String {
215        // serialize the struct to a string
216        PyHelperFuncs::__str__(self)
217    }
218
219    pub fn model_dump_json(&self) -> String {
220        // serialize the struct to a string
221        PyHelperFuncs::__json__(self)
222    }
223
224    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
225        let json_str = serde_json::to_string(&self)?;
226
227        let json_value: Value = serde_json::from_str(&json_str)?;
228
229        // Create a new Python dictionary
230        let dict = PyDict::new(py);
231
232        // Convert JSON to Python dict
233        json_to_pyobject(py, &json_value, &dict)?;
234
235        // Return the Python dictionary
236        Ok(dict.into())
237    }
238
239    // Convert python dict into a drift profile
240    #[pyo3(signature = (path=None))]
241    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
242        Ok(PyHelperFuncs::save_to_json(
243            self,
244            path,
245            FileName::CustomDriftProfile.to_str(),
246        )?)
247    }
248
249    #[getter]
250    pub fn uid(&self) -> String {
251        self.config.uid.clone()
252    }
253
254    #[setter]
255    pub fn set_uid(&mut self, uid: String) {
256        self.config.uid = uid;
257    }
258
259    #[staticmethod]
260    pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
261        let json_value = pyobject_to_json(data).unwrap();
262
263        let string = serde_json::to_string(&json_value).unwrap();
264        serde_json::from_str(&string).expect("Failed to load drift profile")
265    }
266
267    #[staticmethod]
268    pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
269        // deserialize the string to a struct
270        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
271    }
272
273    #[staticmethod]
274    pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
275        let file = std::fs::read_to_string(&path)?;
276
277        Ok(serde_json::from_str(&file)?)
278    }
279
280    #[allow(clippy::too_many_arguments)]
281    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
282    pub fn update_config_args(
283        &mut self,
284        space: Option<String>,
285        name: Option<String>,
286        version: Option<String>,
287        uid: Option<String>,
288        alert_config: Option<CustomMetricAlertConfig>,
289    ) -> Result<(), TypeError> {
290        self.config
291            .update_config_args(space, name, version, uid, alert_config)
292    }
293
294    #[getter]
295    pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
296        let alert_conditions = &self
297            .config
298            .alert_config
299            .alert_conditions
300            .clone()
301            .ok_or(ProfileError::CustomThresholdNotSetError)?;
302
303        Ok(self
304            .metrics
305            .iter()
306            .map(|(name, value)| {
307                // get the alert threshold for the metric
308                let alert = alert_conditions
309                    .get(name)
310                    .ok_or(ProfileError::CustomAlertThresholdNotFound)
311                    .unwrap();
312                CustomMetric::new(name, *value, alert.alert_threshold.clone(), alert.delta).unwrap()
313            })
314            .collect())
315    }
316
317    /// Create a profile request from the profile
318    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
319        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
320            None
321        } else {
322            Some(self.config.version.clone())
323        };
324
325        Ok(ProfileRequest {
326            space: self.config.space.clone(),
327            profile: self.model_dump_json(),
328            drift_type: self.config.drift_type.clone(),
329            version_request: Some(VersionRequest {
330                version,
331                version_type: VersionType::Minor,
332                pre_tag: None,
333                build_tag: None,
334            }),
335            active: false,
336            deactivate_others: false,
337        })
338    }
339}
340
341impl ProfileBaseArgs for CustomDriftProfile {
342    type Config = CustomMetricDriftConfig;
343
344    fn config(&self) -> &Self::Config {
345        &self.config
346    }
347
348    fn get_base_args(&self) -> ProfileArgs {
349        ProfileArgs {
350            name: self.config.name.clone(),
351            space: self.config.space.clone(),
352            version: Some(self.config.version.clone()),
353            schedule: self.config.alert_config.schedule.clone(),
354            scouter_version: self.scouter_version.clone(),
355            drift_type: self.config.drift_type.clone(),
356        }
357    }
358
359    fn to_value(&self) -> Value {
360        serde_json::to_value(self).unwrap()
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Checks constructor defaults and in-place updates through `update_config_args`.
    #[test]
    fn test_drift_config() {
        // Placeholder space/name with an explicit version and the default alert config.
        let mut drift_config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        // Replacement alert config that dispatches to Slack on a custom schedule.
        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        // update: only name and alert_config are supplied; the other fields stay unchanged
        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    /// Checks that building a profile records metric baselines, the scouter version,
    /// and per-metric alert conditions.
    #[test]
    fn test_custom_drift_profile() {
        let alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        // One metric with an explicit delta and one without.
        let custom_metrics = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];

        let profile = CustomDriftProfile::new(drift_config, custom_metrics).unwrap();
        // The JSON dump must at least be parseable.
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
        // `new` must have populated alert conditions from the metrics.
        let conditions = profile.config.alert_config.alert_conditions.unwrap();
        assert_eq!(conditions["mae"].alert_threshold, AlertThreshold::Above);
        assert_eq!(conditions["mae"].delta, Some(2.3));
        assert_eq!(
            conditions["accuracy"].alert_threshold,
            AlertThreshold::Below
        );
        assert_eq!(conditions["accuracy"].delta, None);
    }
}