scouter_types/custom/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8    ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use pyo3::prelude::*;
12use pyo3::types::PyDict;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::path::PathBuf;
17
/// Configuration for a custom-metric drift profile.
///
/// Exposed to Python as a `#[pyclass]`; every field has a generated getter and
/// setter. The struct is (de)serializable with serde, which is how
/// `load_from_json_file` reads it back from disk.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomMetricDriftConfig {
    /// Sample size; the Python constructor defaults this to 25.
    #[pyo3(get, set)]
    pub sample_size: usize,

    /// Space the profile belongs to; forwarded into `DriftArgs` by `get_drift_args`.
    #[pyo3(get, set)]
    pub space: String,

    /// Name of the profile; forwarded into `DriftArgs` by `get_drift_args`.
    #[pyo3(get, set)]
    pub name: String,

    /// Version of the profile; forwarded into `DriftArgs` by `get_drift_args`.
    #[pyo3(get, set)]
    pub version: String,

    /// Alert configuration; its `dispatch_config` is what `get_drift_args` hands out.
    #[pyo3(get, set)]
    pub alert_config: CustomMetricAlertConfig,

    /// Drift type tag. Falls back to `DriftType::Custom` when the field is
    /// missing from the serialized input.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}
40
/// Serde default for `CustomMetricDriftConfig::drift_type`: always `DriftType::Custom`.
fn default_drift_type() -> DriftType {
    DriftType::Custom
}
44
45impl DispatchDriftConfig for CustomMetricDriftConfig {
46    fn get_drift_args(&self) -> DriftArgs {
47        DriftArgs {
48            name: self.name.clone(),
49            space: self.space.clone(),
50            version: self.version.clone(),
51            dispatch_config: self.alert_config.dispatch_config.clone(),
52        }
53    }
54}
55
56#[pymethods]
57#[allow(clippy::too_many_arguments)]
58impl CustomMetricDriftConfig {
59    #[new]
60    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
61    pub fn new(
62        space: &str,
63        name: &str,
64        version: &str,
65        sample_size: usize,
66        alert_config: CustomMetricAlertConfig,
67        config_path: Option<PathBuf>,
68    ) -> Result<Self, ProfileError> {
69        if let Some(config_path) = config_path {
70            let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
71            return Ok(config);
72        }
73
74        Ok(Self {
75            sample_size,
76            space: space.to_string(),
77            name: name.to_string(),
78            version: version.to_string(),
79            alert_config,
80            drift_type: DriftType::Custom,
81        })
82    }
83
84    #[staticmethod]
85    pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
86        // deserialize the string to a struct
87
88        let file = std::fs::read_to_string(&path)?;
89
90        Ok(serde_json::from_str(&file)?)
91    }
92
93    pub fn __str__(&self) -> String {
94        // serialize the struct to a string
95        ProfileFuncs::__str__(self)
96    }
97
98    pub fn model_dump_json(&self) -> String {
99        // serialize the struct to a string
100        ProfileFuncs::__json__(self)
101    }
102
103    #[allow(clippy::too_many_arguments)]
104    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
105    pub fn update_config_args(
106        &mut self,
107        space: Option<String>,
108        name: Option<String>,
109        version: Option<String>,
110        alert_config: Option<CustomMetricAlertConfig>,
111    ) -> Result<(), TypeError> {
112        if name.is_some() {
113            self.name = name.ok_or(TypeError::MissingNameError)?;
114        }
115
116        if space.is_some() {
117            self.space = space.ok_or(TypeError::MissingSpaceError)?;
118        }
119
120        if version.is_some() {
121            self.version = version.ok_or(TypeError::MissingVersionError)?;
122        }
123
124        if alert_config.is_some() {
125            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
126        }
127
128        Ok(())
129    }
130}
131
/// Drift profile built from a set of custom metrics.
///
/// Exposed to Python as a `#[pyclass]` with read-only getters, and
/// (de)serializable with serde.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomDriftProfile {
    /// Drift configuration; its alert conditions are populated from the
    /// metrics when the profile is constructed through `new`.
    #[pyo3(get)]
    pub config: CustomMetricDriftConfig,

    /// Metric name mapped to the metric value supplied at construction.
    #[pyo3(get)]
    pub metrics: HashMap<String, f64>,

    /// Scouter version recorded with the profile; `new` falls back to this
    /// crate's `CARGO_PKG_VERSION` when none is given.
    #[pyo3(get)]
    pub scouter_version: String,
}
144
145#[pymethods]
146impl CustomDriftProfile {
147    #[new]
148    #[pyo3(signature = (config, metrics, scouter_version=None))]
149    pub fn new(
150        mut config: CustomMetricDriftConfig,
151        metrics: Vec<CustomMetric>,
152        scouter_version: Option<String>,
153    ) -> Result<Self, ProfileError> {
154        if metrics.is_empty() {
155            return Err(TypeError::NoMetricsError.into());
156        }
157
158        config.alert_config.set_alert_conditions(&metrics);
159
160        let metric_vals = metrics.iter().map(|m| (m.name.clone(), m.value)).collect();
161
162        let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
163
164        Ok(Self {
165            config,
166            metrics: metric_vals,
167            scouter_version,
168        })
169    }
170
171    pub fn __str__(&self) -> String {
172        // serialize the struct to a string
173        ProfileFuncs::__str__(self)
174    }
175
176    pub fn model_dump_json(&self) -> String {
177        // serialize the struct to a string
178        ProfileFuncs::__json__(self)
179    }
180
181    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
182        let json_str = serde_json::to_string(&self)?;
183
184        let json_value: Value = serde_json::from_str(&json_str)?;
185
186        // Create a new Python dictionary
187        let dict = PyDict::new(py);
188
189        // Convert JSON to Python dict
190        json_to_pyobject(py, &json_value, &dict)?;
191
192        // Return the Python dictionary
193        Ok(dict.into())
194    }
195
196    // Convert python dict into a drift profile
197    #[pyo3(signature = (path=None))]
198    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
199        Ok(ProfileFuncs::save_to_json(
200            self,
201            path,
202            FileName::CustomDriftProfile.to_str(),
203        )?)
204    }
205
206    #[staticmethod]
207    pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
208        let json_value = pyobject_to_json(data).unwrap();
209
210        let string = serde_json::to_string(&json_value).unwrap();
211        serde_json::from_str(&string).expect("Failed to load drift profile")
212    }
213
214    #[staticmethod]
215    pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
216        // deserialize the string to a struct
217        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
218    }
219
220    #[staticmethod]
221    pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
222        let file = std::fs::read_to_string(&path)?;
223
224        Ok(serde_json::from_str(&file)?)
225    }
226
227    #[allow(clippy::too_many_arguments)]
228    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
229    pub fn update_config_args(
230        &mut self,
231        space: Option<String>,
232        name: Option<String>,
233        version: Option<String>,
234        alert_config: Option<CustomMetricAlertConfig>,
235    ) -> Result<(), TypeError> {
236        self.config
237            .update_config_args(space, name, version, alert_config)
238    }
239
240    #[getter]
241    pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
242        let alert_conditions = &self
243            .config
244            .alert_config
245            .alert_conditions
246            .clone()
247            .ok_or(ProfileError::CustomThresholdNotSetError)?;
248
249        Ok(self
250            .metrics
251            .iter()
252            .map(|(name, value)| {
253                // get the alert threshold for the metric
254                let alert = alert_conditions
255                    .get(name)
256                    .ok_or(ProfileError::CustomAlertThresholdNotFound)
257                    .unwrap();
258                CustomMetric::new(
259                    name,
260                    *value,
261                    alert.alert_threshold.clone(),
262                    alert.alert_threshold_value,
263                )
264                .unwrap()
265            })
266            .collect())
267    }
268
269    /// Create a profile request from the profile
270    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
271        Ok(ProfileRequest {
272            space: self.config.space.clone(),
273            profile: self.model_dump_json(),
274            drift_type: self.config.drift_type.clone(),
275        })
276    }
277}
278
279impl ProfileBaseArgs for CustomDriftProfile {
280    fn get_base_args(&self) -> ProfileArgs {
281        ProfileArgs {
282            name: self.config.name.clone(),
283            space: self.config.space.clone(),
284            version: self.config.version.clone(),
285            schedule: self.config.alert_config.schedule.clone(),
286            scouter_version: self.scouter_version.clone(),
287            drift_type: self.config.drift_type.clone(),
288        }
289    }
290
291    fn to_value(&self) -> Value {
292        serde_json::to_value(self).unwrap()
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom::alert::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Construction keeps the supplied values, and `update_config_args` only
    /// replaces the fields that were passed as `Some`.
    #[test]
    fn test_drift_config() {
        let default_alerts = CustomMetricAlertConfig::default();
        let mut config =
            CustomMetricDriftConfig::new(MISSING, MISSING, "0.1.0", 25, default_alerts, None)
                .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        // Swap in a Slack-based alert config together with a new name.
        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let slack_alerts = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        let updated =
            config.update_config_args(None, Some("test".to_string()), None, Some(slack_alerts));
        updated.unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, "0 0 * * * *".to_string());
    }

    /// A profile built from two metrics stores both values, records the crate
    /// version, and derives one alert condition per metric.
    #[test]
    fn test_custom_drift_profile() {
        let ops_genie = OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        };
        let alerts = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(ops_genie),
            ..Default::default()
        };

        let config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alerts, None).unwrap();

        let mae = CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap();
        let accuracy = CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap();

        let profile = CustomDriftProfile::new(config, vec![mae, accuracy], None).unwrap();

        // The JSON dump must at least be parseable.
        serde_json::from_str::<Value>(&profile.model_dump_json())
            .expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile.config.alert_config.alert_conditions.unwrap();

        let mae_condition = &conditions["mae"];
        assert_eq!(mae_condition.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae_condition.alert_threshold_value, Some(2.3));

        let accuracy_condition = &conditions["accuracy"];
        assert_eq!(accuracy_condition.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy_condition.alert_threshold_value, None);
    }
}