1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::{ConfigExt, ProfileRequest};
6use crate::{
7 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use potato_head::create_uuid7;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use scouter_semver::VersionType;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashMap;
18use std::path::PathBuf;
19
20#[pyclass]
21#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
22pub struct CustomMetricDriftConfig {
23 #[pyo3(get, set)]
24 pub sample_size: usize,
25
26 #[pyo3(get, set)]
27 pub space: String,
28
29 #[pyo3(get, set)]
30 pub name: String,
31
32 #[pyo3(get, set)]
33 pub version: String,
34
35 #[pyo3(get, set)]
36 pub uid: String,
37
38 #[pyo3(get, set)]
39 pub alert_config: CustomMetricAlertConfig,
40
41 #[pyo3(get, set)]
42 #[serde(default = "default_drift_type")]
43 pub drift_type: DriftType,
44}
45
46impl ConfigExt for CustomMetricDriftConfig {
47 fn space(&self) -> &str {
48 &self.space
49 }
50
51 fn name(&self) -> &str {
52 &self.name
53 }
54
55 fn version(&self) -> &str {
56 &self.version
57 }
58}
59
60fn default_drift_type() -> DriftType {
61 DriftType::Custom
62}
63
64impl DispatchDriftConfig for CustomMetricDriftConfig {
65 fn get_drift_args(&self) -> DriftArgs {
66 DriftArgs {
67 name: self.name.clone(),
68 space: self.space.clone(),
69 version: self.version.clone(),
70 dispatch_config: self.alert_config.dispatch_config.clone(),
71 }
72 }
73}
74
75#[pymethods]
76#[allow(clippy::too_many_arguments)]
77impl CustomMetricDriftConfig {
78 #[new]
79 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
80 pub fn new(
81 space: &str,
82 name: &str,
83 version: &str,
84 sample_size: usize,
85 alert_config: CustomMetricAlertConfig,
86 config_path: Option<PathBuf>,
87 ) -> Result<Self, ProfileError> {
88 if let Some(config_path) = config_path {
89 let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
90 return Ok(config);
91 }
92
93 Ok(Self {
94 sample_size,
95 space: space.to_string(),
96 name: name.to_string(),
97 version: version.to_string(),
98 uid: create_uuid7(),
99 alert_config,
100 drift_type: DriftType::Custom,
101 })
102 }
103
104 #[staticmethod]
105 pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
106 let file = std::fs::read_to_string(&path)?;
109
110 Ok(serde_json::from_str(&file)?)
111 }
112
113 pub fn __str__(&self) -> String {
114 PyHelperFuncs::__str__(self)
116 }
117
118 pub fn model_dump_json(&self) -> String {
119 PyHelperFuncs::__json__(self)
121 }
122
123 #[allow(clippy::too_many_arguments)]
124 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
125 pub fn update_config_args(
126 &mut self,
127 space: Option<String>,
128 name: Option<String>,
129 version: Option<String>,
130 uid: Option<String>,
131 alert_config: Option<CustomMetricAlertConfig>,
132 ) -> Result<(), TypeError> {
133 if name.is_some() {
134 self.name = name.ok_or(TypeError::MissingNameError)?;
135 }
136
137 if space.is_some() {
138 self.space = space.ok_or(TypeError::MissingSpaceError)?;
139 }
140
141 if version.is_some() {
142 self.version = version.ok_or(TypeError::MissingVersionError)?;
143 }
144
145 if alert_config.is_some() {
146 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
147 }
148
149 if uid.is_some() {
150 self.uid = uid.ok_or(TypeError::MissingUidError)?;
151 }
152
153 Ok(())
154 }
155}
156
157#[pyclass]
158#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
159pub struct CustomDriftProfile {
160 #[pyo3(get)]
161 pub config: CustomMetricDriftConfig,
162
163 #[pyo3(get)]
164 pub metrics: HashMap<String, f64>,
165
166 #[pyo3(get)]
167 pub scouter_version: String,
168}
169
170impl Default for CustomDriftProfile {
171 fn default() -> Self {
172 Self {
173 config: CustomMetricDriftConfig::new(
174 MISSING,
175 MISSING,
176 DEFAULT_VERSION,
177 25,
178 CustomMetricAlertConfig::default(),
179 None,
180 )
181 .unwrap(),
182 metrics: HashMap::new(),
183 scouter_version: scouter_version(),
184 }
185 }
186}
187
188#[pymethods]
189impl CustomDriftProfile {
190 #[new]
191 #[pyo3(signature = (config, metrics))]
192 pub fn new(
193 mut config: CustomMetricDriftConfig,
194 metrics: Vec<CustomMetric>,
195 ) -> Result<Self, ProfileError> {
196 if metrics.is_empty() {
197 return Err(TypeError::NoMetricsError.into());
198 }
199
200 config.alert_config.set_alert_conditions(&metrics);
201
202 let metric_vals = metrics
203 .iter()
204 .map(|m| (m.name.clone(), m.baseline_value))
205 .collect();
206
207 Ok(Self {
208 config,
209 metrics: metric_vals,
210 scouter_version: scouter_version(),
211 })
212 }
213
214 pub fn __str__(&self) -> String {
215 PyHelperFuncs::__str__(self)
217 }
218
219 pub fn model_dump_json(&self) -> String {
220 PyHelperFuncs::__json__(self)
222 }
223
224 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
225 let json_str = serde_json::to_string(&self)?;
226
227 let json_value: Value = serde_json::from_str(&json_str)?;
228
229 let dict = PyDict::new(py);
231
232 json_to_pyobject(py, &json_value, &dict)?;
234
235 Ok(dict.into())
237 }
238
239 #[pyo3(signature = (path=None))]
241 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
242 Ok(PyHelperFuncs::save_to_json(
243 self,
244 path,
245 FileName::CustomDriftProfile.to_str(),
246 )?)
247 }
248
249 #[getter]
250 pub fn uid(&self) -> String {
251 self.config.uid.clone()
252 }
253
254 #[setter]
255 pub fn set_uid(&mut self, uid: String) {
256 self.config.uid = uid;
257 }
258
259 #[staticmethod]
260 pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
261 let json_value = pyobject_to_json(data).unwrap();
262
263 let string = serde_json::to_string(&json_value).unwrap();
264 serde_json::from_str(&string).expect("Failed to load drift profile")
265 }
266
267 #[staticmethod]
268 pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
269 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
271 }
272
273 #[staticmethod]
274 pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
275 let file = std::fs::read_to_string(&path)?;
276
277 Ok(serde_json::from_str(&file)?)
278 }
279
280 #[allow(clippy::too_many_arguments)]
281 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
282 pub fn update_config_args(
283 &mut self,
284 space: Option<String>,
285 name: Option<String>,
286 version: Option<String>,
287 uid: Option<String>,
288 alert_config: Option<CustomMetricAlertConfig>,
289 ) -> Result<(), TypeError> {
290 self.config
291 .update_config_args(space, name, version, uid, alert_config)
292 }
293
294 #[getter]
295 pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
296 let alert_conditions = &self
297 .config
298 .alert_config
299 .alert_conditions
300 .clone()
301 .ok_or(ProfileError::CustomThresholdNotSetError)?;
302
303 Ok(self
304 .metrics
305 .iter()
306 .map(|(name, value)| {
307 let alert = alert_conditions
309 .get(name)
310 .ok_or(ProfileError::CustomAlertThresholdNotFound)
311 .unwrap();
312 CustomMetric::new(name, *value, alert.alert_threshold.clone(), alert.delta).unwrap()
313 })
314 .collect())
315 }
316
317 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
319 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
320 None
321 } else {
322 Some(self.config.version.clone())
323 };
324
325 Ok(ProfileRequest {
326 space: self.config.space.clone(),
327 profile: self.model_dump_json(),
328 drift_type: self.config.drift_type.clone(),
329 version_request: Some(VersionRequest {
330 version,
331 version_type: VersionType::Minor,
332 pre_tag: None,
333 build_tag: None,
334 }),
335 active: false,
336 deactivate_others: false,
337 })
338 }
339}
340
341impl ProfileBaseArgs for CustomDriftProfile {
342 type Config = CustomMetricDriftConfig;
343
344 fn config(&self) -> &Self::Config {
345 &self.config
346 }
347
348 fn get_base_args(&self) -> ProfileArgs {
349 ProfileArgs {
350 name: self.config.name.clone(),
351 space: self.config.space.clone(),
352 version: Some(self.config.version.clone()),
353 schedule: self.config.alert_config.schedule.clone(),
354 scouter_version: self.scouter_version.clone(),
355 drift_type: self.config.drift_type.clone(),
356 }
357 }
358
359 fn to_value(&self) -> Value {
360 serde_json::to_value(self).unwrap()
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Checks the constructor defaults, then a partial `update_config_args` that replaces
    /// the name and the alert config.
    #[test]
    fn test_drift_config() {
        let mut config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let replacement_alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(replacement_alert_config),
            )
            .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(config.alert_config.schedule, "0 0 * * * *".to_string());
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
    }

    /// Builds a profile from two metrics and checks the stored baselines, the version stamp
    /// and the alert conditions registered on the config.
    #[test]
    fn test_custom_drift_profile() {
        let ops_genie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: ops_genie,
            ..Default::default()
        };
        let config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let metrics = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];
        let profile = CustomDriftProfile::new(config, metrics).unwrap();

        // The JSON dump must at least parse as JSON.
        let _parsed: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile.config.alert_config.alert_conditions.unwrap();

        let mae = &conditions["mae"];
        assert_eq!(mae.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae.delta, Some(2.3));

        let accuracy = &conditions["accuracy"];
        assert_eq!(accuracy.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy.delta, None);
    }
}