Skip to main content

scouter_types/custom/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::traits::ConfigExt;
5use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
6use crate::ProfileRequest;
7use crate::{
8    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
9    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
10};
11use core::fmt::Debug;
12use potato_head::create_uuid7;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15use scouter_semver::VersionType;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::path::PathBuf;
20
21#[pyclass]
22#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
23pub struct CustomMetricDriftConfig {
24    #[pyo3(get, set)]
25    pub sample_size: usize,
26
27    #[pyo3(get, set)]
28    pub space: String,
29
30    #[pyo3(get, set)]
31    pub name: String,
32
33    #[pyo3(get, set)]
34    pub version: String,
35
36    #[pyo3(get, set)]
37    pub uid: String,
38
39    #[pyo3(get, set)]
40    pub alert_config: CustomMetricAlertConfig,
41
42    #[pyo3(get, set)]
43    #[serde(default = "default_drift_type")]
44    pub drift_type: DriftType,
45}
46
47impl ConfigExt for CustomMetricDriftConfig {
48    fn space(&self) -> &str {
49        &self.space
50    }
51
52    fn name(&self) -> &str {
53        &self.name
54    }
55
56    fn version(&self) -> &str {
57        &self.version
58    }
59
60    fn uid(&self) -> &str {
61        &self.uid
62    }
63}
64
65fn default_drift_type() -> DriftType {
66    DriftType::Custom
67}
68
69impl DispatchDriftConfig for CustomMetricDriftConfig {
70    fn get_drift_args(&self) -> DriftArgs {
71        DriftArgs {
72            name: self.name.clone(),
73            space: self.space.clone(),
74            version: self.version.clone(),
75            dispatch_config: self.alert_config.dispatch_config.clone(),
76        }
77    }
78}
79
80#[pymethods]
81#[allow(clippy::too_many_arguments)]
82impl CustomMetricDriftConfig {
83    #[new]
84    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
85    pub fn new(
86        space: &str,
87        name: &str,
88        version: &str,
89        sample_size: usize,
90        alert_config: CustomMetricAlertConfig,
91        config_path: Option<PathBuf>,
92    ) -> Result<Self, ProfileError> {
93        if let Some(config_path) = config_path {
94            let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
95            return Ok(config);
96        }
97
98        Ok(Self {
99            sample_size,
100            space: space.to_string(),
101            name: name.to_string(),
102            version: version.to_string(),
103            uid: create_uuid7(),
104            alert_config,
105            drift_type: DriftType::Custom,
106        })
107    }
108
109    #[staticmethod]
110    pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
111        // deserialize the string to a struct
112
113        let file = std::fs::read_to_string(&path)?;
114
115        Ok(serde_json::from_str(&file)?)
116    }
117
118    pub fn __str__(&self) -> String {
119        // serialize the struct to a string
120        PyHelperFuncs::__str__(self)
121    }
122
123    pub fn model_dump_json(&self) -> String {
124        // serialize the struct to a string
125        PyHelperFuncs::__json__(self)
126    }
127
128    #[allow(clippy::too_many_arguments)]
129    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
130    pub fn update_config_args(
131        &mut self,
132        space: Option<String>,
133        name: Option<String>,
134        version: Option<String>,
135        uid: Option<String>,
136        alert_config: Option<CustomMetricAlertConfig>,
137    ) -> Result<(), TypeError> {
138        if name.is_some() {
139            self.name = name.ok_or(TypeError::MissingNameError)?;
140        }
141
142        if space.is_some() {
143            self.space = space.ok_or(TypeError::MissingSpaceError)?;
144        }
145
146        if version.is_some() {
147            self.version = version.ok_or(TypeError::MissingVersionError)?;
148        }
149
150        if alert_config.is_some() {
151            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
152        }
153
154        if uid.is_some() {
155            self.uid = uid.ok_or(TypeError::MissingUidError)?;
156        }
157
158        Ok(())
159    }
160}
161
162#[pyclass]
163#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
164pub struct CustomDriftProfile {
165    #[pyo3(get)]
166    pub config: CustomMetricDriftConfig,
167
168    #[pyo3(get)]
169    pub metrics: HashMap<String, f64>,
170
171    #[pyo3(get)]
172    pub scouter_version: String,
173}
174
175impl Default for CustomDriftProfile {
176    fn default() -> Self {
177        Self {
178            config: CustomMetricDriftConfig::new(
179                MISSING,
180                MISSING,
181                DEFAULT_VERSION,
182                25,
183                CustomMetricAlertConfig::default(),
184                None,
185            )
186            .unwrap(),
187            metrics: HashMap::new(),
188            scouter_version: scouter_version(),
189        }
190    }
191}
192
193#[pymethods]
194impl CustomDriftProfile {
195    #[new]
196    #[pyo3(signature = (config, metrics))]
197    pub fn new(
198        mut config: CustomMetricDriftConfig,
199        metrics: Vec<CustomMetric>,
200    ) -> Result<Self, ProfileError> {
201        if metrics.is_empty() {
202            return Err(TypeError::NoMetricsError.into());
203        }
204
205        config.alert_config.set_alert_conditions(&metrics);
206
207        let metric_vals = metrics
208            .iter()
209            .map(|m| (m.name.clone(), m.baseline_value))
210            .collect();
211
212        Ok(Self {
213            config,
214            metrics: metric_vals,
215            scouter_version: scouter_version(),
216        })
217    }
218
219    pub fn __str__(&self) -> String {
220        // serialize the struct to a string
221        PyHelperFuncs::__str__(self)
222    }
223
224    pub fn model_dump_json(&self) -> String {
225        // serialize the struct to a string
226        PyHelperFuncs::__json__(self)
227    }
228
229    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
230        let json_str = serde_json::to_string(&self)?;
231
232        let json_value: Value = serde_json::from_str(&json_str)?;
233
234        // Create a new Python dictionary
235        let dict = PyDict::new(py);
236
237        // Convert JSON to Python dict
238        json_to_pyobject(py, &json_value, &dict)?;
239
240        // Return the Python dictionary
241        Ok(dict.into())
242    }
243
244    #[getter]
245    pub fn drift_type(&self) -> DriftType {
246        self.config.drift_type.clone()
247    }
248
249    // Convert python dict into a drift profile
250    #[pyo3(signature = (path=None))]
251    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
252        Ok(PyHelperFuncs::save_to_json(
253            self,
254            path,
255            FileName::CustomDriftProfile.to_str(),
256        )?)
257    }
258
259    #[getter]
260    pub fn uid(&self) -> String {
261        self.config.uid.clone()
262    }
263
264    #[setter]
265    pub fn set_uid(&mut self, uid: String) {
266        self.config.uid = uid;
267    }
268
269    #[staticmethod]
270    pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
271        let json_value = pyobject_to_json(data).unwrap();
272
273        let string = serde_json::to_string(&json_value).unwrap();
274        serde_json::from_str(&string).expect("Failed to load drift profile")
275    }
276
277    #[staticmethod]
278    pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
279        // deserialize the string to a struct
280        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
281    }
282
283    #[staticmethod]
284    pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
285        let file = std::fs::read_to_string(&path)?;
286
287        Ok(serde_json::from_str(&file)?)
288    }
289
290    #[allow(clippy::too_many_arguments)]
291    #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
292    pub fn update_config_args(
293        &mut self,
294        space: Option<String>,
295        name: Option<String>,
296        version: Option<String>,
297        uid: Option<String>,
298        alert_config: Option<CustomMetricAlertConfig>,
299    ) -> Result<(), TypeError> {
300        self.config
301            .update_config_args(space, name, version, uid, alert_config)
302    }
303
304    #[getter]
305    pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
306        let alert_conditions = &self
307            .config
308            .alert_config
309            .alert_conditions
310            .clone()
311            .ok_or(ProfileError::CustomThresholdNotSetError)?;
312
313        Ok(self
314            .metrics
315            .iter()
316            .map(|(name, value)| {
317                // get the alert threshold for the metric
318                let alert = alert_conditions
319                    .get(name)
320                    .ok_or(ProfileError::CustomAlertThresholdNotFound)
321                    .unwrap();
322                CustomMetric::new(name, *value, alert.alert_threshold.clone(), alert.delta).unwrap()
323            })
324            .collect())
325    }
326
327    /// Create a profile request from the profile
328    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
329        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
330            None
331        } else {
332            Some(self.config.version.clone())
333        };
334
335        Ok(ProfileRequest {
336            space: self.config.space.clone(),
337            profile: self.model_dump_json(),
338            drift_type: self.config.drift_type.clone(),
339            version_request: Some(VersionRequest {
340                version,
341                version_type: VersionType::Minor,
342                pre_tag: None,
343                build_tag: None,
344            }),
345            active: false,
346            deactivate_others: false,
347        })
348    }
349}
350
351impl ProfileBaseArgs for CustomDriftProfile {
352    type Config = CustomMetricDriftConfig;
353
354    fn config(&self) -> &Self::Config {
355        &self.config
356    }
357
358    fn get_base_args(&self) -> ProfileArgs {
359        ProfileArgs {
360            name: self.config.name.clone(),
361            space: self.config.space.clone(),
362            version: Some(self.config.version.clone()),
363            schedule: self.config.alert_config.schedule.clone(),
364            scouter_version: self.scouter_version.clone(),
365            drift_type: self.config.drift_type.clone(),
366        }
367    }
368
369    fn to_value(&self) -> Value {
370        serde_json::to_value(self).unwrap()
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AlertDispatchConfig, AlertThreshold, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Cron schedule used by both tests.
    const HOURLY: &str = "0 0 * * * *";

    /// A default-constructed config, then updated through `update_config_args`.
    #[test]
    fn test_drift_config() {
        let mut drift_config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        // Fresh config: placeholder identity, explicit version, default dispatch.
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let updated_alerts = CustomMetricAlertConfig {
            schedule: HOURLY.to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(updated_alerts),
            )
            .unwrap();

        // Only the supplied fields change.
        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(drift_config.alert_config.schedule, HOURLY.to_string());
    }

    /// Building a profile records baselines and per-metric alert conditions.
    #[test]
    fn test_custom_drift_profile() {
        let ops_genie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alert_config = CustomMetricAlertConfig {
            schedule: HOURLY.to_string(),
            dispatch_config: ops_genie,
            ..Default::default()
        };

        let drift_config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let mae = CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap();
        let accuracy = CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap();

        let profile = CustomDriftProfile::new(drift_config, vec![mae, accuracy]).unwrap();

        // The JSON dump must be valid JSON.
        let dumped = profile.model_dump_json();
        let _: Value = serde_json::from_str(&dumped).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile.config.alert_config.alert_conditions.unwrap();

        let mae_condition = &conditions["mae"];
        assert_eq!(mae_condition.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae_condition.delta, Some(2.3));

        let accuracy_condition = &conditions["accuracy"];
        assert_eq!(accuracy_condition.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy_condition.delta, None);
    }
}
460            AlertThreshold::Below
461        );
462        assert_eq!(conditions["accuracy"].delta, None);
463    }
464}