1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::ProfileRequest;
6use crate::{
7 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8 PyHelperFuncs, VersionRequest, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use pyo3::prelude::*;
12use pyo3::types::PyDict;
13use scouter_semver::VersionType;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::collections::HashMap;
17use std::path::PathBuf;
18
/// Drift configuration for custom-metric monitoring.
///
/// Every field is readable and writable from Python (`#[pyo3(get, set)]`).
/// The struct is serialized/deserialized with serde, so field declaration
/// order determines key order in the JSON produced by `model_dump_json`.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomMetricDriftConfig {
    /// Sample size for this config; `new` defaults it to 25.
    // NOTE(review): how the sample size is consumed is not visible here — confirm.
    #[pyo3(get, set)]
    pub sample_size: usize,

    /// Space identifier; `new` defaults it to `MISSING`.
    #[pyo3(get, set)]
    pub space: String,

    /// Name identifier; `new` defaults it to `MISSING`.
    #[pyo3(get, set)]
    pub name: String,

    /// Version string; `new` defaults it to `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    pub version: String,

    /// Alerting settings (schedule, dispatch target, per-metric conditions).
    #[pyo3(get, set)]
    pub alert_config: CustomMetricAlertConfig,

    /// Drift type tag. `new` always sets `DriftType::Custom`, and JSON that
    /// omits this key deserializes to `DriftType::Custom` as well.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,
}

/// Serde default for `CustomMetricDriftConfig::drift_type`, used when the key
/// is absent from the JSON being deserialized.
fn default_drift_type() -> DriftType {
    DriftType::Custom
}
45
46impl DispatchDriftConfig for CustomMetricDriftConfig {
47 fn get_drift_args(&self) -> DriftArgs {
48 DriftArgs {
49 name: self.name.clone(),
50 space: self.space.clone(),
51 version: self.version.clone(),
52 dispatch_config: self.alert_config.dispatch_config.clone(),
53 }
54 }
55}
56
57#[pymethods]
58#[allow(clippy::too_many_arguments)]
59impl CustomMetricDriftConfig {
60 #[new]
61 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
62 pub fn new(
63 space: &str,
64 name: &str,
65 version: &str,
66 sample_size: usize,
67 alert_config: CustomMetricAlertConfig,
68 config_path: Option<PathBuf>,
69 ) -> Result<Self, ProfileError> {
70 if let Some(config_path) = config_path {
71 let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
72 return Ok(config);
73 }
74
75 Ok(Self {
76 sample_size,
77 space: space.to_string(),
78 name: name.to_string(),
79 version: version.to_string(),
80 alert_config,
81 drift_type: DriftType::Custom,
82 })
83 }
84
85 #[staticmethod]
86 pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
87 let file = std::fs::read_to_string(&path)?;
90
91 Ok(serde_json::from_str(&file)?)
92 }
93
94 pub fn __str__(&self) -> String {
95 PyHelperFuncs::__str__(self)
97 }
98
99 pub fn model_dump_json(&self) -> String {
100 PyHelperFuncs::__json__(self)
102 }
103
104 #[allow(clippy::too_many_arguments)]
105 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
106 pub fn update_config_args(
107 &mut self,
108 space: Option<String>,
109 name: Option<String>,
110 version: Option<String>,
111 alert_config: Option<CustomMetricAlertConfig>,
112 ) -> Result<(), TypeError> {
113 if name.is_some() {
114 self.name = name.ok_or(TypeError::MissingNameError)?;
115 }
116
117 if space.is_some() {
118 self.space = space.ok_or(TypeError::MissingSpaceError)?;
119 }
120
121 if version.is_some() {
122 self.version = version.ok_or(TypeError::MissingVersionError)?;
123 }
124
125 if alert_config.is_some() {
126 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
127 }
128
129 Ok(())
130 }
131}
132
/// A custom-metric drift profile: a drift config plus the metric values it
/// was built from.
///
/// Fields are read-only from Python (`#[pyo3(get)]`). The struct is the
/// serde (de)serialized form of the profile, so field order sets JSON key order.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CustomDriftProfile {
    /// Drift config; `CustomDriftProfile::new` fills its alert conditions
    /// from the supplied metrics.
    #[pyo3(get)]
    pub config: CustomMetricDriftConfig,

    /// Metric name -> metric value, taken from the `CustomMetric`s given to `new`.
    #[pyo3(get)]
    pub metrics: HashMap<String, f64>,

    /// Version string returned by `scouter_version()` when the profile was built.
    #[pyo3(get)]
    pub scouter_version: String,
}
145
146impl Default for CustomDriftProfile {
147 fn default() -> Self {
148 Self {
149 config: CustomMetricDriftConfig::new(
150 MISSING,
151 MISSING,
152 DEFAULT_VERSION,
153 25,
154 CustomMetricAlertConfig::default(),
155 None,
156 )
157 .unwrap(),
158 metrics: HashMap::new(),
159 scouter_version: scouter_version(),
160 }
161 }
162}
163
164#[pymethods]
165impl CustomDriftProfile {
166 #[new]
167 #[pyo3(signature = (config, metrics))]
168 pub fn new(
169 mut config: CustomMetricDriftConfig,
170 metrics: Vec<CustomMetric>,
171 ) -> Result<Self, ProfileError> {
172 if metrics.is_empty() {
173 return Err(TypeError::NoMetricsError.into());
174 }
175
176 config.alert_config.set_alert_conditions(&metrics);
177
178 let metric_vals = metrics.iter().map(|m| (m.name.clone(), m.value)).collect();
179
180 Ok(Self {
181 config,
182 metrics: metric_vals,
183 scouter_version: scouter_version(),
184 })
185 }
186
187 pub fn __str__(&self) -> String {
188 PyHelperFuncs::__str__(self)
190 }
191
192 pub fn model_dump_json(&self) -> String {
193 PyHelperFuncs::__json__(self)
195 }
196
197 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
198 let json_str = serde_json::to_string(&self)?;
199
200 let json_value: Value = serde_json::from_str(&json_str)?;
201
202 let dict = PyDict::new(py);
204
205 json_to_pyobject(py, &json_value, &dict)?;
207
208 Ok(dict.into())
210 }
211
212 #[pyo3(signature = (path=None))]
214 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
215 Ok(PyHelperFuncs::save_to_json(
216 self,
217 path,
218 FileName::CustomDriftProfile.to_str(),
219 )?)
220 }
221
222 #[staticmethod]
223 pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
224 let json_value = pyobject_to_json(data).unwrap();
225
226 let string = serde_json::to_string(&json_value).unwrap();
227 serde_json::from_str(&string).expect("Failed to load drift profile")
228 }
229
230 #[staticmethod]
231 pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
232 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
234 }
235
236 #[staticmethod]
237 pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
238 let file = std::fs::read_to_string(&path)?;
239
240 Ok(serde_json::from_str(&file)?)
241 }
242
243 #[allow(clippy::too_many_arguments)]
244 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
245 pub fn update_config_args(
246 &mut self,
247 space: Option<String>,
248 name: Option<String>,
249 version: Option<String>,
250 alert_config: Option<CustomMetricAlertConfig>,
251 ) -> Result<(), TypeError> {
252 self.config
253 .update_config_args(space, name, version, alert_config)
254 }
255
256 #[getter]
257 pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
258 let alert_conditions = &self
259 .config
260 .alert_config
261 .alert_conditions
262 .clone()
263 .ok_or(ProfileError::CustomThresholdNotSetError)?;
264
265 Ok(self
266 .metrics
267 .iter()
268 .map(|(name, value)| {
269 let alert = alert_conditions
271 .get(name)
272 .ok_or(ProfileError::CustomAlertThresholdNotFound)
273 .unwrap();
274 CustomMetric::new(
275 name,
276 *value,
277 alert.alert_threshold.clone(),
278 alert.alert_threshold_value,
279 )
280 .unwrap()
281 })
282 .collect())
283 }
284
285 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
287 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
288 None
289 } else {
290 Some(self.config.version.clone())
291 };
292
293 Ok(ProfileRequest {
294 space: self.config.space.clone(),
295 profile: self.model_dump_json(),
296 drift_type: self.config.drift_type.clone(),
297 version_request: VersionRequest {
298 version,
299 version_type: VersionType::Minor,
300 pre_tag: None,
301 build_tag: None,
302 },
303 active: false,
304 deactivate_others: false,
305 })
306 }
307}
308
309impl ProfileBaseArgs for CustomDriftProfile {
310 fn get_base_args(&self) -> ProfileArgs {
311 ProfileArgs {
312 name: self.config.name.clone(),
313 space: self.config.space.clone(),
314 version: Some(self.config.version.clone()),
315 schedule: self.config.alert_config.schedule.clone(),
316 scouter_version: self.scouter_version.clone(),
317 drift_type: self.config.drift_type.clone(),
318 }
319 }
320
321 fn to_value(&self) -> Value {
322 serde_json::to_value(self).unwrap()
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// `new` applies the given values, and `update_config_args` only
    /// overwrites the fields that are passed.
    #[test]
    fn test_drift_config() {
        let mut config = CustomMetricDriftConfig::new(
            MISSING,
            MISSING,
            "0.1.0",
            25,
            CustomMetricAlertConfig::default(),
            None,
        )
        .unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let hourly = "0 0 * * * *";
        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let replacement_alerts = CustomMetricAlertConfig {
            schedule: hourly.to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        config
            .update_config_args(None, Some("test".to_string()), None, Some(replacement_alerts))
            .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, hourly);
    }

    /// Building a profile stores metric values and derives per-metric alert
    /// conditions from the supplied metrics.
    #[test]
    fn test_custom_drift_profile() {
        let opsgenie = AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        });
        let alerts = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: opsgenie,
            ..Default::default()
        };

        let config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alerts, None).unwrap();

        let inputs = vec![
            CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap(),
            CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap(),
        ];

        let profile = CustomDriftProfile::new(config, inputs).unwrap();

        // The JSON dump must at least be parseable.
        let dumped = profile.model_dump_json();
        let _parsed: Value = serde_json::from_str(&dumped).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile
            .config
            .alert_config
            .alert_conditions
            .as_ref()
            .unwrap();

        let mae = &conditions["mae"];
        assert_eq!(mae.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae.alert_threshold_value, Some(2.3));

        let accuracy = &conditions["accuracy"];
        assert_eq!(accuracy.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy.alert_threshold_value, None);
    }
}