1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::traits::ConfigExt;
5use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
6use crate::ProfileRequest;
7use crate::{
8 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
9 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
10};
11use core::fmt::Debug;
12use potato_head::create_uuid7;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15use scouter_semver::VersionType;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19use std::path::PathBuf;
20
21#[pyclass]
22#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
23pub struct CustomMetricDriftConfig {
24 #[pyo3(get, set)]
25 pub sample_size: usize,
26
27 #[pyo3(get, set)]
28 pub space: String,
29
30 #[pyo3(get, set)]
31 pub name: String,
32
33 #[pyo3(get, set)]
34 pub version: String,
35
36 #[pyo3(get, set)]
37 pub uid: String,
38
39 #[pyo3(get, set)]
40 pub alert_config: CustomMetricAlertConfig,
41
42 #[pyo3(get, set)]
43 #[serde(default = "default_drift_type")]
44 pub drift_type: DriftType,
45}
46
47impl ConfigExt for CustomMetricDriftConfig {
48 fn space(&self) -> &str {
49 &self.space
50 }
51
52 fn name(&self) -> &str {
53 &self.name
54 }
55
56 fn version(&self) -> &str {
57 &self.version
58 }
59
60 fn uid(&self) -> &str {
61 &self.uid
62 }
63}
64
65fn default_drift_type() -> DriftType {
66 DriftType::Custom
67}
68
69impl DispatchDriftConfig for CustomMetricDriftConfig {
70 fn get_drift_args(&self) -> DriftArgs {
71 DriftArgs {
72 name: self.name.clone(),
73 space: self.space.clone(),
74 version: self.version.clone(),
75 dispatch_config: self.alert_config.dispatch_config.clone(),
76 }
77 }
78}
79
80#[pymethods]
81#[allow(clippy::too_many_arguments)]
82impl CustomMetricDriftConfig {
83 #[new]
84 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
85 pub fn new(
86 space: &str,
87 name: &str,
88 version: &str,
89 sample_size: usize,
90 alert_config: CustomMetricAlertConfig,
91 config_path: Option<PathBuf>,
92 ) -> Result<Self, ProfileError> {
93 if let Some(config_path) = config_path {
94 let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
95 return Ok(config);
96 }
97
98 Ok(Self {
99 sample_size,
100 space: space.to_string(),
101 name: name.to_string(),
102 version: version.to_string(),
103 uid: create_uuid7(),
104 alert_config,
105 drift_type: DriftType::Custom,
106 })
107 }
108
109 #[staticmethod]
110 pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
111 let file = std::fs::read_to_string(&path)?;
114
115 Ok(serde_json::from_str(&file)?)
116 }
117
118 pub fn __str__(&self) -> String {
119 PyHelperFuncs::__str__(self)
121 }
122
123 pub fn model_dump_json(&self) -> String {
124 PyHelperFuncs::__json__(self)
126 }
127
128 #[allow(clippy::too_many_arguments)]
129 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
130 pub fn update_config_args(
131 &mut self,
132 space: Option<String>,
133 name: Option<String>,
134 version: Option<String>,
135 uid: Option<String>,
136 alert_config: Option<CustomMetricAlertConfig>,
137 ) -> Result<(), TypeError> {
138 if name.is_some() {
139 self.name = name.ok_or(TypeError::MissingNameError)?;
140 }
141
142 if space.is_some() {
143 self.space = space.ok_or(TypeError::MissingSpaceError)?;
144 }
145
146 if version.is_some() {
147 self.version = version.ok_or(TypeError::MissingVersionError)?;
148 }
149
150 if alert_config.is_some() {
151 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
152 }
153
154 if uid.is_some() {
155 self.uid = uid.ok_or(TypeError::MissingUidError)?;
156 }
157
158 Ok(())
159 }
160}
161
162#[pyclass]
163#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
164pub struct CustomDriftProfile {
165 #[pyo3(get)]
166 pub config: CustomMetricDriftConfig,
167
168 #[pyo3(get)]
169 pub metrics: HashMap<String, f64>,
170
171 #[pyo3(get)]
172 pub scouter_version: String,
173}
174
175impl Default for CustomDriftProfile {
176 fn default() -> Self {
177 Self {
178 config: CustomMetricDriftConfig::new(
179 MISSING,
180 MISSING,
181 DEFAULT_VERSION,
182 25,
183 CustomMetricAlertConfig::default(),
184 None,
185 )
186 .unwrap(),
187 metrics: HashMap::new(),
188 scouter_version: scouter_version(),
189 }
190 }
191}
192
193#[pymethods]
194impl CustomDriftProfile {
195 #[new]
196 #[pyo3(signature = (config, metrics))]
197 pub fn new(
198 mut config: CustomMetricDriftConfig,
199 metrics: Vec<CustomMetric>,
200 ) -> Result<Self, ProfileError> {
201 if metrics.is_empty() {
202 return Err(TypeError::NoMetricsError.into());
203 }
204
205 config.alert_config.set_alert_conditions(&metrics);
206
207 let metric_vals = metrics
208 .iter()
209 .map(|m| (m.name.clone(), m.baseline_value))
210 .collect();
211
212 Ok(Self {
213 config,
214 metrics: metric_vals,
215 scouter_version: scouter_version(),
216 })
217 }
218
219 pub fn __str__(&self) -> String {
220 PyHelperFuncs::__str__(self)
222 }
223
224 pub fn model_dump_json(&self) -> String {
225 PyHelperFuncs::__json__(self)
227 }
228
229 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
230 let json_str = serde_json::to_string(&self)?;
231
232 let json_value: Value = serde_json::from_str(&json_str)?;
233
234 let dict = PyDict::new(py);
236
237 json_to_pyobject(py, &json_value, &dict)?;
239
240 Ok(dict.into())
242 }
243
244 #[getter]
245 pub fn drift_type(&self) -> DriftType {
246 self.config.drift_type.clone()
247 }
248
249 #[pyo3(signature = (path=None))]
251 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
252 Ok(PyHelperFuncs::save_to_json(
253 self,
254 path,
255 FileName::CustomDriftProfile.to_str(),
256 )?)
257 }
258
259 #[getter]
260 pub fn uid(&self) -> String {
261 self.config.uid.clone()
262 }
263
264 #[setter]
265 pub fn set_uid(&mut self, uid: String) {
266 self.config.uid = uid;
267 }
268
269 #[staticmethod]
270 pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
271 let json_value = pyobject_to_json(data).unwrap();
272
273 let string = serde_json::to_string(&json_value).unwrap();
274 serde_json::from_str(&string).expect("Failed to load drift profile")
275 }
276
277 #[staticmethod]
278 pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
279 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
281 }
282
283 #[staticmethod]
284 pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
285 let file = std::fs::read_to_string(&path)?;
286
287 Ok(serde_json::from_str(&file)?)
288 }
289
290 #[allow(clippy::too_many_arguments)]
291 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None))]
292 pub fn update_config_args(
293 &mut self,
294 space: Option<String>,
295 name: Option<String>,
296 version: Option<String>,
297 uid: Option<String>,
298 alert_config: Option<CustomMetricAlertConfig>,
299 ) -> Result<(), TypeError> {
300 self.config
301 .update_config_args(space, name, version, uid, alert_config)
302 }
303
304 #[getter]
305 pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
306 let alert_conditions = &self
307 .config
308 .alert_config
309 .alert_conditions
310 .clone()
311 .ok_or(ProfileError::CustomThresholdNotSetError)?;
312
313 Ok(self
314 .metrics
315 .iter()
316 .map(|(name, value)| {
317 let alert = alert_conditions
319 .get(name)
320 .ok_or(ProfileError::CustomAlertThresholdNotFound)
321 .unwrap();
322 CustomMetric::new(name, *value, alert.alert_threshold.clone(), alert.delta).unwrap()
323 })
324 .collect())
325 }
326
327 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
329 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
330 None
331 } else {
332 Some(self.config.version.clone())
333 };
334
335 Ok(ProfileRequest {
336 space: self.config.space.clone(),
337 profile: self.model_dump_json(),
338 drift_type: self.config.drift_type.clone(),
339 version_request: Some(VersionRequest {
340 version,
341 version_type: VersionType::Minor,
342 pre_tag: None,
343 build_tag: None,
344 }),
345 active: false,
346 deactivate_others: false,
347 })
348 }
349}
350
351impl ProfileBaseArgs for CustomDriftProfile {
352 type Config = CustomMetricDriftConfig;
353
354 fn config(&self) -> &Self::Config {
355 &self.config
356 }
357
358 fn get_base_args(&self) -> ProfileArgs {
359 ProfileArgs {
360 name: self.config.name.clone(),
361 space: self.config.space.clone(),
362 version: Some(self.config.version.clone()),
363 schedule: self.config.alert_config.schedule.clone(),
364 scouter_version: self.scouter_version.clone(),
365 drift_type: self.config.drift_type.clone(),
366 }
367 }
368
369 fn to_value(&self) -> Value {
370 serde_json::to_value(self).unwrap()
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Builds the "scouter"/"ML"/"0.1.0" profile with an `mae` and an `accuracy`
    /// metric, shared by several tests below.
    fn build_sample_profile() -> CustomDriftProfile {
        let alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let drift_config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let custom_metrics = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];

        CustomDriftProfile::new(drift_config, custom_metrics).unwrap()
    }

    #[test]
    fn test_drift_config() {
        let mut drift_config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let test_slack_dispatch_config = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let new_alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(test_slack_dispatch_config.clone()),
            ..Default::default()
        };

        drift_config
            .update_config_args(
                None,
                Some("test".to_string()),
                None,
                None,
                Some(new_alert_config),
            )
            .unwrap();

        assert_eq!(drift_config.name, "test");
        assert_eq!(
            drift_config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(test_slack_dispatch_config)
        );
        assert_eq!(
            drift_config.alert_config.schedule,
            "0 0 * * * *".to_string()
        );
    }

    #[test]
    fn test_update_config_args_sets_identity_fields() {
        let mut config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            DEFAULT_VERSION,
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();
        let original_alert_config = config.alert_config.clone();

        config
            .update_config_args(
                Some("space".to_string()),
                Some("name".to_string()),
                Some("1.2.3".to_string()),
                Some("custom-uid".to_string()),
                None,
            )
            .unwrap();

        assert_eq!(config.space, "space");
        assert_eq!(config.name, "name");
        assert_eq!(config.version, "1.2.3");
        assert_eq!(config.uid, "custom-uid");
        // `None` must leave the existing values untouched.
        assert_eq!(config.alert_config, original_alert_config);
        assert_eq!(config.sample_size, 25);
    }

    #[test]
    fn test_custom_drift_profile() {
        let profile = build_sample_profile();
        let _: Value =
            serde_json::from_str(&profile.model_dump_json()).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));
        let conditions = profile.config.alert_config.alert_conditions.unwrap();
        assert_eq!(conditions["mae"].alert_threshold, AlertThreshold::Above);
        assert_eq!(conditions["mae"].delta, Some(2.3));
        assert_eq!(
            conditions["accuracy"].alert_threshold,
            AlertThreshold::Below
        );
        assert_eq!(conditions["accuracy"].delta, None);
    }

    #[test]
    fn test_empty_metrics_rejected() {
        let config = CustomMetricDriftConfig::new(
            "scouter",
            "ML",
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        assert!(CustomDriftProfile::new(config, Vec::new()).is_err());
    }

    #[test]
    fn test_custom_metrics_roundtrip() {
        let profile = build_sample_profile();
        let metrics = profile.custom_metrics().unwrap();

        // HashMap iteration order is unspecified, so compare sorted names.
        let mut names: Vec<String> = metrics.iter().map(|m| m.name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["accuracy", "mae"]);
    }

    #[test]
    fn test_custom_metrics_without_conditions_errors() {
        let mut profile = build_sample_profile();
        profile.config.alert_config.alert_conditions = None;

        assert!(profile.custom_metrics().is_err());
    }

    #[test]
    fn test_create_profile_request_version() {
        let profile = build_sample_profile();
        let request = profile.create_profile_request().unwrap();
        assert_eq!(request.space, "scouter");
        assert!(!request.active);
        assert!(!request.deactivate_others);
        let version_request = request
            .version_request
            .expect("version request should always be set");
        assert_eq!(version_request.version.as_deref(), Some("0.1.0"));

        // The default version is sent as `None` so the server picks the version.
        let mut default_profile = build_sample_profile();
        default_profile.config.version = DEFAULT_VERSION.to_string();
        let default_request = default_profile.create_profile_request().unwrap();
        let default_version_request = default_request
            .version_request
            .expect("version request should always be set");
        assert_eq!(default_version_request.version, None);
    }
}