1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
5use crate::ProfileRequest;
6use crate::VersionRequest;
7use crate::{
8 DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
9 ProfileFuncs, DEFAULT_VERSION, MISSING,
10};
11use chrono::Utc;
12use core::fmt::Debug;
13use pyo3::prelude::*;
14use pyo3::types::PyDict;
15use scouter_semver::VersionType;
16use serde::de::{self, MapAccess, Visitor};
17use serde::ser::SerializeStruct;
18use serde::{Deserialize, Deserializer, Serialize, Serializer};
19use serde_json::Value;
20use std::collections::{BTreeMap, HashMap};
21use std::path::PathBuf;
22use tracing::debug;
23
/// Kind of binning used for a feature in a PSI profile.
///
/// Exposed to Python as an enum (`#[pyclass(eq)]`, so instances compare with `==`) and
/// serialized by serde as the variant name.
#[pyclass(eq)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum BinType {
    /// Presumably numeric bins bounded by `Bin::lower_limit` / `Bin::upper_limit` — confirm.
    Numeric,
    /// Presumably one bin per category value — confirm against the profiler.
    Category,
}
30
/// Configuration for PSI drift detection.
///
/// Identifies a profile by `space` / `name` / `version` and carries its alert configuration.
/// Exposed to Python as a class; all fields except `feature_map` are readable and writable there.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftConfig {
    /// Space the profile belongs to (`MISSING` sentinel when not provided).
    #[pyo3(get, set)]
    pub space: String,

    /// Profile name (`MISSING` sentinel when not provided).
    #[pyo3(get, set)]
    pub name: String,

    /// Profile version string (`DEFAULT_VERSION` when not provided).
    #[pyo3(get, set)]
    pub version: String,

    /// Alerting configuration; also supplies the dispatch config and schedule.
    #[pyo3(get, set)]
    pub alert_config: PsiAlertConfig,

    /// Feature map; read-only from Python and defaulted when absent from serialized input.
    /// Replaced from Rust via `update_feature_map`.
    #[pyo3(get)]
    #[serde(default)]
    pub feature_map: FeatureMap,

    /// Drift type tag; deserializes to `DriftType::Psi` when the key is missing.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,

    /// Optional list of feature names; presumably the features to bin as categorical — confirm.
    #[pyo3(get, set)]
    pub categorical_features: Option<Vec<String>>,
}
57
/// Serde default for `PsiDriftConfig::drift_type`: always `DriftType::Psi`.
fn default_drift_type() -> DriftType {
    DriftType::Psi
}
61
62impl PsiDriftConfig {
63 pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
64 self.feature_map = feature_map;
65 }
66}
67
68#[pymethods]
69#[allow(clippy::too_many_arguments)]
70impl PsiDriftConfig {
71 #[new]
72 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None))]
73 pub fn new(
74 space: &str,
75 name: &str,
76 version: &str,
77 alert_config: PsiAlertConfig,
78 config_path: Option<PathBuf>,
79 categorical_features: Option<Vec<String>>,
80 ) -> Result<Self, ProfileError> {
81 if let Some(config_path) = config_path {
82 let config = PsiDriftConfig::load_from_json_file(config_path);
83 return config;
84 }
85
86 if name == MISSING || space == MISSING {
87 debug!("Name and space were not provided. Defaulting to __missing__");
88 }
89
90 Ok(Self {
91 name: name.to_string(),
92 space: space.to_string(),
93 version: version.to_string(),
94 alert_config,
95 categorical_features,
96 feature_map: FeatureMap::default(),
97 drift_type: DriftType::Psi,
98 })
99 }
100
101 #[staticmethod]
102 pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
103 let file = std::fs::read_to_string(&path)?;
106
107 Ok(serde_json::from_str(&file)?)
108 }
109
110 pub fn __str__(&self) -> String {
111 ProfileFuncs::__str__(self)
113 }
114
115 pub fn model_dump_json(&self) -> String {
116 ProfileFuncs::__json__(self)
118 }
119
120 #[allow(clippy::too_many_arguments)]
121 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
122 pub fn update_config_args(
123 &mut self,
124 space: Option<String>,
125 name: Option<String>,
126 version: Option<String>,
127 alert_config: Option<PsiAlertConfig>,
128 ) -> Result<(), TypeError> {
129 if name.is_some() {
130 self.name = name.ok_or(TypeError::MissingNameError)?;
131 }
132
133 if space.is_some() {
134 self.space = space.ok_or(TypeError::MissingSpaceError)?;
135 }
136
137 if version.is_some() {
138 self.version = version.ok_or(TypeError::MissingVersionError)?;
139 }
140
141 if alert_config.is_some() {
142 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
143 }
144
145 Ok(())
146 }
147}
148
149impl Default for PsiDriftConfig {
150 fn default() -> Self {
151 PsiDriftConfig {
152 name: "__missing__".to_string(),
153 space: "__missing__".to_string(),
154 version: DEFAULT_VERSION.to_string(),
155 feature_map: FeatureMap::default(),
156 alert_config: PsiAlertConfig::default(),
157 drift_type: DriftType::Psi,
158 categorical_features: None,
159 }
160 }
161}
162impl DispatchDriftConfig for PsiDriftConfig {
165 fn get_drift_args(&self) -> DriftArgs {
166 DriftArgs {
167 name: self.name.clone(),
168 space: self.space.clone(),
169 version: self.version.clone(),
170 dispatch_config: self.alert_config.dispatch_config.clone(),
171 }
172 }
173}
174
/// A single bin of a feature's PSI profile. Read-only from Python.
///
/// Serialization is hand-written (see the `Serialize` / `Deserialize` impls) so that infinite
/// limits survive a JSON round trip as strings.
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
pub struct Bin {
    /// Bin identifier.
    #[pyo3(get)]
    pub id: usize,

    /// Lower bound of the bin; may be infinite or `None`.
    #[pyo3(get)]
    pub lower_limit: Option<f64>,

    /// Upper bound of the bin; may be infinite or `None`.
    #[pyo3(get)]
    pub upper_limit: Option<f64>,

    /// Presumably the fraction of reference samples falling in this bin — confirm.
    #[pyo3(get)]
    pub proportion: f64,
}
190
191impl Serialize for Bin {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: Serializer,
195 {
196 let mut state = serializer.serialize_struct("Bin", 4)?;
197 state.serialize_field("id", &self.id)?;
198
199 state.serialize_field(
200 "lower_limit",
201 &self.lower_limit.map(|v| {
202 if v.is_infinite() {
203 serde_json::Value::String(if v.is_sign_positive() {
204 "inf".to_string()
205 } else {
206 "-inf".to_string()
207 })
208 } else {
209 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
210 }
211 }),
212 )?;
213 state.serialize_field(
214 "upper_limit",
215 &self.upper_limit.map(|v| {
216 if v.is_infinite() {
217 serde_json::Value::String(if v.is_sign_positive() {
218 "inf".to_string()
219 } else {
220 "-inf".to_string()
221 })
222 } else {
223 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
224 }
225 }),
226 )?;
227 state.serialize_field("proportion", &self.proportion)?;
228 state.end()
229 }
230}
231
232impl<'de> Deserialize<'de> for Bin {
233 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
234 where
235 D: Deserializer<'de>,
236 {
237 #[derive(Deserialize)]
238 #[serde(untagged)]
239 enum NumberOrString {
240 Number(f64),
241 String(String),
242 }
243
244 #[derive(Deserialize)]
245 #[serde(field_identifier, rename_all = "snake_case")]
246 enum Field {
247 Id,
248 LowerLimit,
249 UpperLimit,
250 Proportion,
251 }
252
253 struct BinVisitor;
254
255 impl<'de> Visitor<'de> for BinVisitor {
256 type Value = Bin;
257
258 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
259 formatter.write_str("struct Bin")
260 }
261
262 fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
263 where
264 V: MapAccess<'de>,
265 {
266 let mut id = None;
267 let mut lower_limit = None;
268 let mut upper_limit = None;
269 let mut proportion = None;
270
271 while let Some(key) = map.next_key()? {
272 match key {
273 Field::Id => {
274 id = Some(map.next_value()?);
275 }
276 Field::LowerLimit => {
277 let val: Option<NumberOrString> = map.next_value()?;
278 lower_limit = Some(val.map(|v| match v {
279 NumberOrString::String(s) => match s.as_str() {
280 "inf" => f64::INFINITY,
281 "-inf" => f64::NEG_INFINITY,
282 _ => s.parse().unwrap(),
283 },
284 NumberOrString::Number(n) => n,
285 }));
286 }
287 Field::UpperLimit => {
288 let val: Option<NumberOrString> = map.next_value()?;
289 upper_limit = Some(val.map(|v| match v {
290 NumberOrString::String(s) => match s.as_str() {
291 "inf" => f64::INFINITY,
292 "-inf" => f64::NEG_INFINITY,
293 _ => s.parse().unwrap(),
294 },
295 NumberOrString::Number(n) => n,
296 }));
297 }
298 Field::Proportion => {
299 proportion = Some(map.next_value()?);
300 }
301 }
302 }
303
304 Ok(Bin {
305 id: id.ok_or_else(|| de::Error::missing_field("id"))?,
306 lower_limit: lower_limit
307 .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
308 upper_limit: upper_limit
309 .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
310 proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
311 })
312 }
313 }
314
315 const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
316 deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
317 }
318}
319
/// PSI profile for a single feature. Read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiFeatureDriftProfile {
    /// Feature identifier; presumably the feature name — confirm.
    #[pyo3(get)]
    pub id: String,

    /// Bins computed for this feature.
    #[pyo3(get)]
    pub bins: Vec<Bin>,

    /// UTC timestamp attached to the profile; presumably its creation time — confirm.
    #[pyo3(get)]
    pub timestamp: chrono::DateTime<Utc>,

    /// Whether the bins are numeric or categorical.
    #[pyo3(get)]
    pub bin_type: BinType,
}
335
/// Complete PSI drift profile: per-feature profiles plus the config they were built with.
/// Read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftProfile {
    /// Per-feature profiles keyed by feature name.
    #[pyo3(get)]
    pub features: HashMap<String, PsiFeatureDriftProfile>,

    /// Drift configuration (identity, alerting, drift type).
    #[pyo3(get)]
    pub config: PsiDriftConfig,

    /// Version string from `scouter_version()` when built via `PsiDriftProfile::new`.
    #[pyo3(get)]
    pub scouter_version: String,
}
348
349impl PsiDriftProfile {
350 pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
351 Self {
352 features,
353 config,
354 scouter_version: scouter_version(),
355 }
356 }
357}
358
359#[pymethods]
360impl PsiDriftProfile {
361 pub fn __str__(&self) -> String {
362 ProfileFuncs::__str__(self)
364 }
365
366 pub fn model_dump_json(&self) -> String {
367 ProfileFuncs::__json__(self)
369 }
370 #[allow(clippy::useless_conversion)]
372 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
373 let json_str = serde_json::to_string(&self)?;
374
375 let json_value: Value = serde_json::from_str(&json_str)?;
376
377 let dict = PyDict::new(py);
379
380 json_to_pyobject(py, &json_value, &dict)?;
382
383 Ok(dict.into())
385 }
386
387 #[staticmethod]
388 pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
389 let file = std::fs::read_to_string(&path)?;
390
391 Ok(serde_json::from_str(&file)?)
392 }
393
394 #[staticmethod]
395 pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
396 let json_value = pyobject_to_json(data).unwrap();
397
398 let string = serde_json::to_string(&json_value).unwrap();
399 serde_json::from_str(&string).expect("Failed to load drift profile")
400 }
401
402 #[staticmethod]
403 pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
404 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
406 }
407
408 #[pyo3(signature = (path=None))]
410 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
411 Ok(ProfileFuncs::save_to_json(
412 self,
413 path,
414 FileName::PsiDriftProfile.to_str(),
415 )?)
416 }
417
418 #[allow(clippy::too_many_arguments)]
419 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
420 pub fn update_config_args(
421 &mut self,
422 space: Option<String>,
423 name: Option<String>,
424 version: Option<String>,
425 alert_config: Option<PsiAlertConfig>,
426 ) -> Result<(), TypeError> {
427 self.config
428 .update_config_args(space, name, version, alert_config)
429 }
430
431 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
433 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
434 None
435 } else {
436 Some(self.config.version.clone())
437 };
438
439 Ok(ProfileRequest {
440 space: self.config.space.clone(),
441 profile: self.model_dump_json(),
442 drift_type: self.config.drift_type.clone(),
443 version_request: VersionRequest {
444 version,
445 version_type: VersionType::Minor,
446 pre_tag: None,
447 build_tag: None,
448 },
449 })
450 }
451}
452
/// Per-feature PSI values for one profile identity. Read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PsiDriftMap {
    /// PSI value keyed by feature name; presumably computed drift scores — confirm.
    #[pyo3(get)]
    pub features: HashMap<String, f64>,

    /// Profile name.
    #[pyo3(get)]
    pub name: String,

    /// Profile space.
    #[pyo3(get)]
    pub space: String,

    /// Profile version.
    #[pyo3(get)]
    pub version: String,
}
468
469impl PsiDriftMap {
470 pub fn new(space: String, name: String, version: String) -> Self {
471 Self {
472 features: HashMap::new(),
473 name,
474 space,
475 version,
476 }
477 }
478}
479
480#[pymethods]
481#[allow(clippy::new_without_default)]
482impl PsiDriftMap {
483 pub fn __str__(&self) -> String {
484 ProfileFuncs::__str__(self)
486 }
487
488 pub fn model_dump_json(&self) -> String {
489 ProfileFuncs::__json__(self)
491 }
492
493 #[staticmethod]
494 pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
495 Ok(serde_json::from_str(&json_string)?)
497 }
498
499 #[pyo3(signature = (path=None))]
500 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
501 Ok(ProfileFuncs::save_to_json(
502 self,
503 path,
504 FileName::PsiDriftMap.to_str(),
505 )?)
506 }
507}
508
509impl ProfileBaseArgs for PsiDriftProfile {
511 fn get_base_args(&self) -> ProfileArgs {
513 ProfileArgs {
514 name: self.config.name.clone(),
515 space: self.config.space.clone(),
516 version: Some(self.config.version.clone()),
517 schedule: self.config.alert_config.schedule.clone(),
518 scouter_version: self.scouter_version.clone(),
519 drift_type: self.config.drift_type.clone(),
520 }
521 }
522
523 fn to_value(&self) -> Value {
525 serde_json::to_value(self).unwrap()
526 }
527}
528
/// Observed distribution of one feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionData {
    /// Number of samples the distribution was computed from.
    pub sample_size: u64,
    /// Value per bin; presumably keyed by `Bin::id` with a proportion value — confirm.
    pub bins: BTreeMap<usize, f64>,
}
534
/// Distributions for a set of features, keyed by feature name (sorted, via `BTreeMap`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureDistributions {
    pub distributions: BTreeMap<String, DistributionData>,
}
539
540impl FeatureDistributions {
541 pub fn is_empty(&self) -> bool {
542 self.distributions.is_empty()
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config from the sentinel defaults, checks its fields, then verifies that
    /// `update_config_args` replaces the name.
    #[test]
    fn test_drift_config() {
        let alert_config = PsiAlertConfig::default();
        let mut drift_config =
            PsiDriftConfig::new(MISSING, MISSING, DEFAULT_VERSION, alert_config, None, None)
                .unwrap();

        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.0.0");
        assert_eq!(drift_config.alert_config, PsiAlertConfig::default());

        let new_name = Some("test".to_string());
        drift_config
            .update_config_args(None, new_name, None, None)
            .unwrap();

        assert_eq!(drift_config.name, "test");
    }
}