1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7 DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
8 ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use chrono::Utc;
11use core::fmt::Debug;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use serde::de::{self, MapAccess, Visitor};
15use serde::ser::SerializeStruct;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17use serde_json::Value;
18use std::collections::{BTreeMap, HashMap};
19use std::path::PathBuf;
20use tracing::debug;
21
22#[pyclass(eq)]
23#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
24pub enum BinType {
25 Binary,
26 Numeric,
27 Category,
28}
29
30#[pyclass]
31#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
32pub struct PsiDriftConfig {
33 #[pyo3(get, set)]
34 pub space: String,
35
36 #[pyo3(get, set)]
37 pub name: String,
38
39 #[pyo3(get, set)]
40 pub version: String,
41
42 #[pyo3(get, set)]
43 pub alert_config: PsiAlertConfig,
44
45 #[pyo3(get)]
46 #[serde(default)]
47 pub feature_map: FeatureMap,
48
49 #[pyo3(get, set)]
50 #[serde(default = "default_drift_type")]
51 pub drift_type: DriftType,
52}
53
54fn default_drift_type() -> DriftType {
55 DriftType::Psi
56}
57
58impl PsiDriftConfig {
59 pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
60 self.feature_map = feature_map;
61 }
62}
63
64#[pymethods]
65#[allow(clippy::too_many_arguments)]
66impl PsiDriftConfig {
67 #[new]
69 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None))]
70 pub fn new(
71 space: &str,
72 name: &str,
73 version: &str,
74 alert_config: PsiAlertConfig,
75 config_path: Option<PathBuf>,
76 ) -> Result<Self, ProfileError> {
77 if let Some(config_path) = config_path {
78 let config = PsiDriftConfig::load_from_json_file(config_path);
79 return config;
80 }
81
82 if name == MISSING || space == MISSING {
83 debug!("Name and space were not provided. Defaulting to __missing__");
84 }
85
86 Ok(Self {
87 name: name.to_string(),
88 space: space.to_string(),
89 version: version.to_string(),
90 alert_config,
91 feature_map: FeatureMap::default(),
92 drift_type: DriftType::Psi,
93 })
94 }
95
96 #[staticmethod]
97 pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
98 let file = std::fs::read_to_string(&path)?;
101
102 Ok(serde_json::from_str(&file)?)
103 }
104
105 pub fn __str__(&self) -> String {
106 ProfileFuncs::__str__(self)
108 }
109
110 pub fn model_dump_json(&self) -> String {
111 ProfileFuncs::__json__(self)
113 }
114
115 #[allow(clippy::too_many_arguments)]
116 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
117 pub fn update_config_args(
118 &mut self,
119 space: Option<String>,
120 name: Option<String>,
121 version: Option<String>,
122 alert_config: Option<PsiAlertConfig>,
123 ) -> Result<(), TypeError> {
124 if name.is_some() {
125 self.name = name.ok_or(TypeError::MissingNameError)?;
126 }
127
128 if space.is_some() {
129 self.space = space.ok_or(TypeError::MissingSpaceError)?;
130 }
131
132 if version.is_some() {
133 self.version = version.ok_or(TypeError::MissingVersionError)?;
134 }
135
136 if alert_config.is_some() {
137 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
138 }
139
140 Ok(())
141 }
142}
143
144impl Default for PsiDriftConfig {
145 fn default() -> Self {
146 PsiDriftConfig {
147 name: "__missing__".to_string(),
148 space: "__missing__".to_string(),
149 version: DEFAULT_VERSION.to_string(),
150 feature_map: FeatureMap::default(),
151 alert_config: PsiAlertConfig::default(),
152 drift_type: DriftType::Psi,
153 }
154 }
155}
156impl DispatchDriftConfig for PsiDriftConfig {
159 fn get_drift_args(&self) -> DriftArgs {
160 DriftArgs {
161 name: self.name.clone(),
162 space: self.space.clone(),
163 version: self.version.clone(),
164 dispatch_config: self.alert_config.dispatch_config.clone(),
165 }
166 }
167}
168
169#[pyclass]
170#[derive(Debug, Clone, PartialEq)]
171pub struct Bin {
172 #[pyo3(get)]
173 pub id: usize,
174
175 #[pyo3(get)]
176 pub lower_limit: Option<f64>,
177
178 #[pyo3(get)]
179 pub upper_limit: Option<f64>,
180
181 #[pyo3(get)]
182 pub proportion: f64,
183}
184
185impl Serialize for Bin {
186 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
187 where
188 S: Serializer,
189 {
190 let mut state = serializer.serialize_struct("Bin", 4)?;
191 state.serialize_field("id", &self.id)?;
192
193 state.serialize_field(
194 "lower_limit",
195 &self.lower_limit.map(|v| {
196 if v.is_infinite() {
197 serde_json::Value::String(if v.is_sign_positive() {
198 "inf".to_string()
199 } else {
200 "-inf".to_string()
201 })
202 } else {
203 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
204 }
205 }),
206 )?;
207 state.serialize_field(
208 "upper_limit",
209 &self.upper_limit.map(|v| {
210 if v.is_infinite() {
211 serde_json::Value::String(if v.is_sign_positive() {
212 "inf".to_string()
213 } else {
214 "-inf".to_string()
215 })
216 } else {
217 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
218 }
219 }),
220 )?;
221 state.serialize_field("proportion", &self.proportion)?;
222 state.end()
223 }
224}
225
226impl<'de> Deserialize<'de> for Bin {
227 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
228 where
229 D: Deserializer<'de>,
230 {
231 #[derive(Deserialize)]
232 #[serde(untagged)]
233 enum NumberOrString {
234 Number(f64),
235 String(String),
236 }
237
238 #[derive(Deserialize)]
239 #[serde(field_identifier, rename_all = "snake_case")]
240 enum Field {
241 Id,
242 LowerLimit,
243 UpperLimit,
244 Proportion,
245 }
246
247 struct BinVisitor;
248
249 impl<'de> Visitor<'de> for BinVisitor {
250 type Value = Bin;
251
252 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
253 formatter.write_str("struct Bin")
254 }
255
256 fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
257 where
258 V: MapAccess<'de>,
259 {
260 let mut id = None;
261 let mut lower_limit = None;
262 let mut upper_limit = None;
263 let mut proportion = None;
264
265 while let Some(key) = map.next_key()? {
266 match key {
267 Field::Id => {
268 id = Some(map.next_value()?);
269 }
270 Field::LowerLimit => {
271 let val: Option<NumberOrString> = map.next_value()?;
272 lower_limit = Some(val.map(|v| match v {
273 NumberOrString::String(s) => match s.as_str() {
274 "inf" => f64::INFINITY,
275 "-inf" => f64::NEG_INFINITY,
276 _ => s.parse().unwrap(),
277 },
278 NumberOrString::Number(n) => n,
279 }));
280 }
281 Field::UpperLimit => {
282 let val: Option<NumberOrString> = map.next_value()?;
283 upper_limit = Some(val.map(|v| match v {
284 NumberOrString::String(s) => match s.as_str() {
285 "inf" => f64::INFINITY,
286 "-inf" => f64::NEG_INFINITY,
287 _ => s.parse().unwrap(),
288 },
289 NumberOrString::Number(n) => n,
290 }));
291 }
292 Field::Proportion => {
293 proportion = Some(map.next_value()?);
294 }
295 }
296 }
297
298 Ok(Bin {
299 id: id.ok_or_else(|| de::Error::missing_field("id"))?,
300 lower_limit: lower_limit
301 .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
302 upper_limit: upper_limit
303 .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
304 proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
305 })
306 }
307 }
308
309 const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
310 deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
311 }
312}
313
314#[pyclass]
315#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
316pub struct PsiFeatureDriftProfile {
317 #[pyo3(get)]
318 pub id: String,
319
320 #[pyo3(get)]
321 pub bins: Vec<Bin>,
322
323 #[pyo3(get)]
324 pub timestamp: chrono::DateTime<Utc>,
325
326 #[pyo3(get)]
327 pub bin_type: BinType,
328}
329
330#[pyclass]
331#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
332pub struct PsiDriftProfile {
333 #[pyo3(get)]
334 pub features: HashMap<String, PsiFeatureDriftProfile>,
335
336 #[pyo3(get)]
337 pub config: PsiDriftConfig,
338
339 #[pyo3(get)]
340 pub scouter_version: String,
341}
342
343impl PsiDriftProfile {
344 pub fn new(
345 features: HashMap<String, PsiFeatureDriftProfile>,
346 config: PsiDriftConfig,
347 scouter_version: Option<String>,
348 ) -> Self {
349 let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
350 Self {
351 features,
352 config,
353 scouter_version,
354 }
355 }
356}
357
358#[pymethods]
359impl PsiDriftProfile {
360 pub fn __str__(&self) -> String {
361 ProfileFuncs::__str__(self)
363 }
364
365 pub fn model_dump_json(&self) -> String {
366 ProfileFuncs::__json__(self)
368 }
369 #[allow(clippy::useless_conversion)]
371 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
372 let json_str = serde_json::to_string(&self)?;
373
374 let json_value: Value = serde_json::from_str(&json_str)?;
375
376 let dict = PyDict::new(py);
378
379 json_to_pyobject(py, &json_value, &dict)?;
381
382 Ok(dict.into())
384 }
385
386 #[staticmethod]
387 pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
388 let file = std::fs::read_to_string(&path)?;
389
390 Ok(serde_json::from_str(&file)?)
391 }
392
393 #[staticmethod]
394 pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
395 let json_value = pyobject_to_json(data).unwrap();
396
397 let string = serde_json::to_string(&json_value).unwrap();
398 serde_json::from_str(&string).expect("Failed to load drift profile")
399 }
400
401 #[staticmethod]
402 pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
403 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
405 }
406
407 #[pyo3(signature = (path=None))]
409 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
410 Ok(ProfileFuncs::save_to_json(
411 self,
412 path,
413 FileName::PsiDriftProfile.to_str(),
414 )?)
415 }
416
417 #[allow(clippy::too_many_arguments)]
418 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
419 pub fn update_config_args(
420 &mut self,
421 space: Option<String>,
422 name: Option<String>,
423 version: Option<String>,
424 alert_config: Option<PsiAlertConfig>,
425 ) -> Result<(), TypeError> {
426 self.config
427 .update_config_args(space, name, version, alert_config)
428 }
429
430 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
432 Ok(ProfileRequest {
433 space: self.config.space.clone(),
434 profile: self.model_dump_json(),
435 drift_type: self.config.drift_type.clone(),
436 })
437 }
438}
439
440#[pyclass]
441#[derive(Debug, Serialize, Deserialize, Clone)]
442pub struct PsiDriftMap {
443 #[pyo3(get)]
444 pub features: HashMap<String, f64>,
445
446 #[pyo3(get)]
447 pub name: String,
448
449 #[pyo3(get)]
450 pub space: String,
451
452 #[pyo3(get)]
453 pub version: String,
454}
455
456impl PsiDriftMap {
457 pub fn new(space: String, name: String, version: String) -> Self {
458 Self {
459 features: HashMap::new(),
460 name,
461 space,
462 version,
463 }
464 }
465}
466
467#[pymethods]
468#[allow(clippy::new_without_default)]
469impl PsiDriftMap {
470 pub fn __str__(&self) -> String {
471 ProfileFuncs::__str__(self)
473 }
474
475 pub fn model_dump_json(&self) -> String {
476 ProfileFuncs::__json__(self)
478 }
479
480 #[staticmethod]
481 pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
482 Ok(serde_json::from_str(&json_string)?)
484 }
485
486 #[pyo3(signature = (path=None))]
487 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
488 Ok(ProfileFuncs::save_to_json(
489 self,
490 path,
491 FileName::PsiDriftMap.to_str(),
492 )?)
493 }
494}
495
496impl ProfileBaseArgs for PsiDriftProfile {
498 fn get_base_args(&self) -> ProfileArgs {
500 ProfileArgs {
501 name: self.config.name.clone(),
502 space: self.config.space.clone(),
503 version: self.config.version.clone(),
504 schedule: self.config.alert_config.schedule.clone(),
505 scouter_version: self.scouter_version.clone(),
506 drift_type: self.config.drift_type.clone(),
507 }
508 }
509
510 fn to_value(&self) -> Value {
512 serde_json::to_value(self).unwrap()
513 }
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct FeatureBinProportion {
518 pub feature: String,
519 pub bins: BTreeMap<usize, f64>,
520}
521
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct FeatureBinProportions {
524 pub features: BTreeMap<String, BTreeMap<usize, f64>>,
525}
526
527impl FromIterator<FeatureBinProportion> for FeatureBinProportions {
528 fn from_iter<T: IntoIterator<Item = FeatureBinProportion>>(iter: T) -> Self {
529 let mut feature_map: BTreeMap<String, BTreeMap<usize, f64>> = BTreeMap::new();
530 for feature in iter {
531 feature_map.insert(feature.feature, feature.bins);
532 }
533 FeatureBinProportions {
534 features: feature_map,
535 }
536 }
537}
538
539impl FeatureBinProportions {
540 pub fn from_features(features: Vec<FeatureBinProportion>) -> Self {
541 let mut feature_map: BTreeMap<String, BTreeMap<usize, f64>> = BTreeMap::new();
542 for feature in features {
543 feature_map.insert(feature.feature, feature.bins);
544 }
545 FeatureBinProportions {
546 features: feature_map,
547 }
548 }
549
550 pub fn get(&self, feature: &str, bin: &usize) -> Option<&f64> {
551 self.features.get(feature).and_then(|f| f.get(bin))
552 }
553
554 pub fn is_empty(&self) -> bool {
555 self.features.is_empty()
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drift_config() {
        let mut drift_config = PsiDriftConfig::new(
            MISSING,
            MISSING,
            DEFAULT_VERSION,
            PsiAlertConfig::default(),
            None,
        )
        .unwrap();
        assert_eq!(drift_config.name, "__missing__");
        assert_eq!(drift_config.space, "__missing__");
        assert_eq!(drift_config.version, "0.1.0");
        assert_eq!(drift_config.alert_config, PsiAlertConfig::default());

        // `new` with all defaults and `Default::default()` must describe the same config.
        assert_eq!(drift_config, PsiDriftConfig::default());

        drift_config
            .update_config_args(None, Some("test".to_string()), None, None)
            .unwrap();

        assert_eq!(drift_config.name, "test");
    }

    #[test]
    fn test_bin_serde_round_trip_with_infinite_and_missing_limits() {
        let bins = vec![
            Bin {
                id: 0,
                lower_limit: Some(f64::NEG_INFINITY),
                upper_limit: Some(1.5),
                proportion: 0.25,
            },
            Bin {
                id: 1,
                lower_limit: Some(1.5),
                upper_limit: Some(f64::INFINITY),
                proportion: 0.75,
            },
            Bin {
                id: 2,
                lower_limit: None,
                upper_limit: None,
                proportion: 0.0,
            },
        ];

        let json = serde_json::to_string(&bins).unwrap();
        // JSON numbers cannot hold infinities, so they must be encoded as strings.
        assert!(json.contains("\"-inf\""));
        assert!(json.contains("\"inf\""));

        let decoded: Vec<Bin> = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, bins);
    }
}