scouter_types/custom/
profile.rs

1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::ProfileRequest;
6use crate::{
7    DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8    PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use pyo3::prelude::*;
12use pyo3::types::PyDict;
13use scouter_semver::VersionType;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::path::PathBuf;
18
19#[pyclass]
20#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
21pub struct CustomMetricDriftConfig {
22    #[pyo3(get, set)]
23    pub sample_size: usize,
24
25    #[pyo3(get, set)]
26    pub space: String,
27
28    #[pyo3(get, set)]
29    pub name: String,
30
31    #[pyo3(get, set)]
32    pub version: String,
33
34    #[pyo3(get, set)]
35    pub alert_config: CustomMetricAlertConfig,
36
37    #[pyo3(get, set)]
38    #[serde(default = "default_drift_type")]
39    pub drift_type: DriftType,
40}
41
42fn default_drift_type() -> DriftType {
43    DriftType::Custom
44}
45
46impl DispatchDriftConfig for CustomMetricDriftConfig {
47    fn get_drift_args(&self) -> DriftArgs {
48        DriftArgs {
49            name: self.name.clone(),
50            space: self.space.clone(),
51            version: self.version.clone(),
52            dispatch_config: self.alert_config.dispatch_config.clone(),
53        }
54    }
55}
56
57#[pymethods]
58#[allow(clippy::too_many_arguments)]
59impl CustomMetricDriftConfig {
60    #[new]
61    #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
62    pub fn new(
63        space: &str,
64        name: &str,
65        version: &str,
66        sample_size: usize,
67        alert_config: CustomMetricAlertConfig,
68        config_path: Option<PathBuf>,
69    ) -> Result<Self, ProfileError> {
70        if let Some(config_path) = config_path {
71            let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
72            return Ok(config);
73        }
74
75        Ok(Self {
76            sample_size,
77            space: space.to_string(),
78            name: name.to_string(),
79            version: version.to_string(),
80            alert_config,
81            drift_type: DriftType::Custom,
82        })
83    }
84
85    #[staticmethod]
86    pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
87        // deserialize the string to a struct
88
89        let file = std::fs::read_to_string(&path)?;
90
91        Ok(serde_json::from_str(&file)?)
92    }
93
94    pub fn __str__(&self) -> String {
95        // serialize the struct to a string
96        PyHelperFuncs::__str__(self)
97    }
98
99    pub fn model_dump_json(&self) -> String {
100        // serialize the struct to a string
101        PyHelperFuncs::__json__(self)
102    }
103
104    #[allow(clippy::too_many_arguments)]
105    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
106    pub fn update_config_args(
107        &mut self,
108        space: Option<String>,
109        name: Option<String>,
110        version: Option<String>,
111        alert_config: Option<CustomMetricAlertConfig>,
112    ) -> Result<(), TypeError> {
113        if name.is_some() {
114            self.name = name.ok_or(TypeError::MissingNameError)?;
115        }
116
117        if space.is_some() {
118            self.space = space.ok_or(TypeError::MissingSpaceError)?;
119        }
120
121        if version.is_some() {
122            self.version = version.ok_or(TypeError::MissingVersionError)?;
123        }
124
125        if alert_config.is_some() {
126            self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
127        }
128
129        Ok(())
130    }
131}
132
133#[pyclass]
134#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
135pub struct CustomDriftProfile {
136    #[pyo3(get)]
137    pub config: CustomMetricDriftConfig,
138
139    #[pyo3(get)]
140    pub metrics: HashMap<String, f64>,
141
142    #[pyo3(get)]
143    pub scouter_version: String,
144}
145
146impl Default for CustomDriftProfile {
147    fn default() -> Self {
148        Self {
149            config: CustomMetricDriftConfig::new(
150                MISSING,
151                MISSING,
152                DEFAULT_VERSION,
153                25,
154                CustomMetricAlertConfig::default(),
155                None,
156            )
157            .unwrap(),
158            metrics: HashMap::new(),
159            scouter_version: scouter_version(),
160        }
161    }
162}
163
164#[pymethods]
165impl CustomDriftProfile {
166    #[new]
167    #[pyo3(signature = (config, metrics))]
168    pub fn new(
169        mut config: CustomMetricDriftConfig,
170        metrics: Vec<CustomMetric>,
171    ) -> Result<Self, ProfileError> {
172        if metrics.is_empty() {
173            return Err(TypeError::NoMetricsError.into());
174        }
175
176        config.alert_config.set_alert_conditions(&metrics);
177
178        let metric_vals = metrics.iter().map(|m| (m.name.clone(), m.value)).collect();
179
180        Ok(Self {
181            config,
182            metrics: metric_vals,
183            scouter_version: scouter_version(),
184        })
185    }
186
187    pub fn __str__(&self) -> String {
188        // serialize the struct to a string
189        PyHelperFuncs::__str__(self)
190    }
191
192    pub fn model_dump_json(&self) -> String {
193        // serialize the struct to a string
194        PyHelperFuncs::__json__(self)
195    }
196
197    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
198        let json_str = serde_json::to_string(&self)?;
199
200        let json_value: Value = serde_json::from_str(&json_str)?;
201
202        // Create a new Python dictionary
203        let dict = PyDict::new(py);
204
205        // Convert JSON to Python dict
206        json_to_pyobject(py, &json_value, &dict)?;
207
208        // Return the Python dictionary
209        Ok(dict.into())
210    }
211
212    // Convert python dict into a drift profile
213    #[pyo3(signature = (path=None))]
214    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
215        Ok(PyHelperFuncs::save_to_json(
216            self,
217            path,
218            FileName::CustomDriftProfile.to_str(),
219        )?)
220    }
221
222    #[staticmethod]
223    pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
224        let json_value = pyobject_to_json(data).unwrap();
225
226        let string = serde_json::to_string(&json_value).unwrap();
227        serde_json::from_str(&string).expect("Failed to load drift profile")
228    }
229
230    #[staticmethod]
231    pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
232        // deserialize the string to a struct
233        serde_json::from_str(&json_string).expect("Failed to load monitor profile")
234    }
235
236    #[staticmethod]
237    pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
238        let file = std::fs::read_to_string(&path)?;
239
240        Ok(serde_json::from_str(&file)?)
241    }
242
243    #[allow(clippy::too_many_arguments)]
244    #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
245    pub fn update_config_args(
246        &mut self,
247        space: Option<String>,
248        name: Option<String>,
249        version: Option<String>,
250        alert_config: Option<CustomMetricAlertConfig>,
251    ) -> Result<(), TypeError> {
252        self.config
253            .update_config_args(space, name, version, alert_config)
254    }
255
256    #[getter]
257    pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
258        let alert_conditions = &self
259            .config
260            .alert_config
261            .alert_conditions
262            .clone()
263            .ok_or(ProfileError::CustomThresholdNotSetError)?;
264
265        Ok(self
266            .metrics
267            .iter()
268            .map(|(name, value)| {
269                // get the alert threshold for the metric
270                let alert = alert_conditions
271                    .get(name)
272                    .ok_or(ProfileError::CustomAlertThresholdNotFound)
273                    .unwrap();
274                CustomMetric::new(
275                    name,
276                    *value,
277                    alert.alert_threshold.clone(),
278                    alert.alert_threshold_value,
279                )
280                .unwrap()
281            })
282            .collect())
283    }
284
285    /// Create a profile request from the profile
286    pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
287        let version: Option<String> = if self.config.version == DEFAULT_VERSION {
288            None
289        } else {
290            Some(self.config.version.clone())
291        };
292
293        Ok(ProfileRequest {
294            space: self.config.space.clone(),
295            profile: self.model_dump_json(),
296            drift_type: self.config.drift_type.clone(),
297            version_request: VersionRequest {
298                version,
299                version_type: VersionType::Minor,
300                pre_tag: None,
301                build_tag: None,
302            },
303            active: false,
304            deactivate_others: false,
305        })
306    }
307}
308
309impl ProfileBaseArgs for CustomDriftProfile {
310    fn get_base_args(&self) -> ProfileArgs {
311        ProfileArgs {
312            name: self.config.name.clone(),
313            space: self.config.space.clone(),
314            version: Some(self.config.version.clone()),
315            schedule: self.config.alert_config.schedule.clone(),
316            scouter_version: self.scouter_version.clone(),
317            drift_type: self.config.drift_type.clone(),
318        }
319    }
320
321    fn to_value(&self) -> Value {
322        serde_json::to_value(self).unwrap()
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// A config built from placeholders carries the `MISSING` identifiers and the
    /// default dispatch target; `update_config_args` then overwrites only the
    /// fields it is given.
    #[test]
    fn test_drift_config() {
        let mut config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let replacement = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        config
            .update_config_args(None, Some("test".to_string()), None, Some(replacement))
            .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, "0 0 * * * *");
    }

    /// Building a profile from two metrics records both values, stamps the crate
    /// version, and derives one alert condition per metric.
    #[test]
    fn test_custom_drift_profile() {
        let opsgenie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: opsgenie,
            ..Default::default()
        };

        let config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None)
                .unwrap();

        let metrics = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];

        let profile = CustomDriftProfile::new(config, metrics).unwrap();

        // The JSON dump must at least be well-formed.
        serde_json::from_str::<Value>(&profile.model_dump_json())
            .expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile.config.alert_config.alert_conditions.as_ref().unwrap();

        let mae = &conditions["mae"];
        assert_eq!(mae.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae.alert_threshold_value, Some(2.3));

        let accuracy = &conditions["accuracy"];
        assert_eq!(accuracy.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy.alert_threshold_value, None);
    }
}