1#![allow(clippy::useless_conversion)]
2use crate::custom::alert::{CustomMetric, CustomMetricAlertConfig};
3use crate::error::{ProfileError, TypeError};
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7 DispatchDriftConfig, DriftArgs, DriftType, FileName, ProfileArgs, ProfileBaseArgs,
8 ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use core::fmt::Debug;
11use pyo3::prelude::*;
12use pyo3::types::PyDict;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::path::PathBuf;
17
18#[pyclass]
19#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
20pub struct CustomMetricDriftConfig {
21 #[pyo3(get, set)]
22 pub sample_size: usize,
23
24 #[pyo3(get, set)]
25 pub space: String,
26
27 #[pyo3(get, set)]
28 pub name: String,
29
30 #[pyo3(get, set)]
31 pub version: String,
32
33 #[pyo3(get, set)]
34 pub alert_config: CustomMetricAlertConfig,
35
36 #[pyo3(get, set)]
37 #[serde(default = "default_drift_type")]
38 pub drift_type: DriftType,
39}
40
41fn default_drift_type() -> DriftType {
42 DriftType::Custom
43}
44
45impl DispatchDriftConfig for CustomMetricDriftConfig {
46 fn get_drift_args(&self) -> DriftArgs {
47 DriftArgs {
48 name: self.name.clone(),
49 space: self.space.clone(),
50 version: self.version.clone(),
51 dispatch_config: self.alert_config.dispatch_config.clone(),
52 }
53 }
54}
55
56#[pymethods]
57#[allow(clippy::too_many_arguments)]
58impl CustomMetricDriftConfig {
59 #[new]
60 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, sample_size=25, alert_config=CustomMetricAlertConfig::default(), config_path=None))]
61 pub fn new(
62 space: &str,
63 name: &str,
64 version: &str,
65 sample_size: usize,
66 alert_config: CustomMetricAlertConfig,
67 config_path: Option<PathBuf>,
68 ) -> Result<Self, ProfileError> {
69 if let Some(config_path) = config_path {
70 let config = CustomMetricDriftConfig::load_from_json_file(config_path)?;
71 return Ok(config);
72 }
73
74 Ok(Self {
75 sample_size,
76 space: space.to_string(),
77 name: name.to_string(),
78 version: version.to_string(),
79 alert_config,
80 drift_type: DriftType::Custom,
81 })
82 }
83
84 #[staticmethod]
85 pub fn load_from_json_file(path: PathBuf) -> Result<CustomMetricDriftConfig, ProfileError> {
86 let file = std::fs::read_to_string(&path)?;
89
90 Ok(serde_json::from_str(&file)?)
91 }
92
93 pub fn __str__(&self) -> String {
94 ProfileFuncs::__str__(self)
96 }
97
98 pub fn model_dump_json(&self) -> String {
99 ProfileFuncs::__json__(self)
101 }
102
103 #[allow(clippy::too_many_arguments)]
104 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
105 pub fn update_config_args(
106 &mut self,
107 space: Option<String>,
108 name: Option<String>,
109 version: Option<String>,
110 alert_config: Option<CustomMetricAlertConfig>,
111 ) -> Result<(), TypeError> {
112 if name.is_some() {
113 self.name = name.ok_or(TypeError::MissingNameError)?;
114 }
115
116 if space.is_some() {
117 self.space = space.ok_or(TypeError::MissingSpaceError)?;
118 }
119
120 if version.is_some() {
121 self.version = version.ok_or(TypeError::MissingVersionError)?;
122 }
123
124 if alert_config.is_some() {
125 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
126 }
127
128 Ok(())
129 }
130}
131
132#[pyclass]
133#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
134pub struct CustomDriftProfile {
135 #[pyo3(get)]
136 pub config: CustomMetricDriftConfig,
137
138 #[pyo3(get)]
139 pub metrics: HashMap<String, f64>,
140
141 #[pyo3(get)]
142 pub scouter_version: String,
143}
144
145#[pymethods]
146impl CustomDriftProfile {
147 #[new]
148 #[pyo3(signature = (config, metrics, scouter_version=None))]
149 pub fn new(
150 mut config: CustomMetricDriftConfig,
151 metrics: Vec<CustomMetric>,
152 scouter_version: Option<String>,
153 ) -> Result<Self, ProfileError> {
154 if metrics.is_empty() {
155 return Err(TypeError::NoMetricsError.into());
156 }
157
158 config.alert_config.set_alert_conditions(&metrics);
159
160 let metric_vals = metrics.iter().map(|m| (m.name.clone(), m.value)).collect();
161
162 let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
163
164 Ok(Self {
165 config,
166 metrics: metric_vals,
167 scouter_version,
168 })
169 }
170
171 pub fn __str__(&self) -> String {
172 ProfileFuncs::__str__(self)
174 }
175
176 pub fn model_dump_json(&self) -> String {
177 ProfileFuncs::__json__(self)
179 }
180
181 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
182 let json_str = serde_json::to_string(&self)?;
183
184 let json_value: Value = serde_json::from_str(&json_str)?;
185
186 let dict = PyDict::new(py);
188
189 json_to_pyobject(py, &json_value, &dict)?;
191
192 Ok(dict.into())
194 }
195
196 #[pyo3(signature = (path=None))]
198 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
199 Ok(ProfileFuncs::save_to_json(
200 self,
201 path,
202 FileName::CustomDriftProfile.to_str(),
203 )?)
204 }
205
206 #[staticmethod]
207 pub fn model_validate(data: &Bound<'_, PyDict>) -> CustomDriftProfile {
208 let json_value = pyobject_to_json(data).unwrap();
209
210 let string = serde_json::to_string(&json_value).unwrap();
211 serde_json::from_str(&string).expect("Failed to load drift profile")
212 }
213
214 #[staticmethod]
215 pub fn model_validate_json(json_string: String) -> CustomDriftProfile {
216 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
218 }
219
220 #[staticmethod]
221 pub fn from_file(path: PathBuf) -> Result<CustomDriftProfile, ProfileError> {
222 let file = std::fs::read_to_string(&path)?;
223
224 Ok(serde_json::from_str(&file)?)
225 }
226
227 #[allow(clippy::too_many_arguments)]
228 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
229 pub fn update_config_args(
230 &mut self,
231 space: Option<String>,
232 name: Option<String>,
233 version: Option<String>,
234 alert_config: Option<CustomMetricAlertConfig>,
235 ) -> Result<(), TypeError> {
236 self.config
237 .update_config_args(space, name, version, alert_config)
238 }
239
240 #[getter]
241 pub fn custom_metrics(&self) -> Result<Vec<CustomMetric>, ProfileError> {
242 let alert_conditions = &self
243 .config
244 .alert_config
245 .alert_conditions
246 .clone()
247 .ok_or(ProfileError::CustomThresholdNotSetError)?;
248
249 Ok(self
250 .metrics
251 .iter()
252 .map(|(name, value)| {
253 let alert = alert_conditions
255 .get(name)
256 .ok_or(ProfileError::CustomAlertThresholdNotFound)
257 .unwrap();
258 CustomMetric::new(
259 name,
260 *value,
261 alert.alert_threshold.clone(),
262 alert.alert_threshold_value,
263 )
264 .unwrap()
265 })
266 .collect())
267 }
268
269 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
271 Ok(ProfileRequest {
272 space: self.config.space.clone(),
273 profile: self.model_dump_json(),
274 drift_type: self.config.drift_type.clone(),
275 })
276 }
277}
278
279impl ProfileBaseArgs for CustomDriftProfile {
280 fn get_base_args(&self) -> ProfileArgs {
281 ProfileArgs {
282 name: self.config.name.clone(),
283 space: self.config.space.clone(),
284 version: self.config.version.clone(),
285 schedule: self.config.alert_config.schedule.clone(),
286 scouter_version: self.scouter_version.clone(),
287 drift_type: self.config.drift_type.clone(),
288 }
289 }
290
291 fn to_value(&self) -> Value {
292 serde_json::to_value(self).unwrap()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom::alert::AlertThreshold;
    use crate::{AlertDispatchConfig, OpsGenieDispatchConfig, SlackDispatchConfig};

    /// Default construction, followed by a partial update of name and alert config.
    #[test]
    fn test_drift_config() {
        let default_alert_config = CustomMetricAlertConfig::default();
        let mut config =
            CustomMetricDriftConfig::new(MISSING, MISSING, "0.1.0", 25, default_alert_config, None)
                .unwrap();

        // Defaults are reflected as-is.
        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::default()
        );

        let slack = SlackDispatchConfig {
            channel: "test-channel".to_string(),
        };
        let updated_alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::Slack(slack.clone()),
            ..Default::default()
        };

        // Only name and alert config are supplied; space and version stay untouched.
        let update_result = config.update_config_args(
            None,
            Some("test".to_string()),
            None,
            Some(updated_alert_config),
        );
        update_result.unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(
            config.alert_config.dispatch_config,
            AlertDispatchConfig::Slack(slack)
        );
        assert_eq!(config.alert_config.schedule, "0 0 * * * *".to_string());
    }

    /// Profile construction populates metrics, version and alert conditions.
    #[test]
    fn test_custom_drift_profile() {
        let ops_genie = OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        };
        let alert_config = CustomMetricAlertConfig {
            schedule: "0 0 * * * *".to_string(),
            dispatch_config: AlertDispatchConfig::OpsGenie(ops_genie),
            ..Default::default()
        };

        let drift_config =
            CustomMetricDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        let mae = CustomMetric::new("mae", 12.4, AlertThreshold::Above, Some(2.3)).unwrap();
        let accuracy = CustomMetric::new("accuracy", 0.85, AlertThreshold::Below, None).unwrap();

        let profile = CustomDriftProfile::new(drift_config, vec![mae, accuracy], None).unwrap();

        // The JSON dump must be parseable.
        let dumped = profile.model_dump_json();
        let _parsed: Value = serde_json::from_str(&dumped).expect("Failed to parse actual JSON");

        assert_eq!(profile.metrics.len(), 2);
        assert_eq!(profile.scouter_version, env!("CARGO_PKG_VERSION"));

        let conditions = profile.config.alert_config.alert_conditions.unwrap();

        let mae_condition = &conditions["mae"];
        assert_eq!(mae_condition.alert_threshold, AlertThreshold::Above);
        assert_eq!(mae_condition.alert_threshold_value, Some(2.3));

        let accuracy_condition = &conditions["accuracy"];
        assert_eq!(accuracy_condition.alert_threshold, AlertThreshold::Below);
        assert_eq!(accuracy_condition.alert_threshold_value, None);
    }
}