1#![allow(clippy::useless_conversion)]
2use crate::error::{ProfileError, TypeError};
3use crate::psi::alert::PsiAlertConfig;
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::ProfileRequest;
6use crate::{
7 DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
8 ProfileFuncs, DEFAULT_VERSION, MISSING,
9};
10use chrono::Utc;
11use core::fmt::Debug;
12use pyo3::prelude::*;
13use pyo3::types::PyDict;
14use serde::de::{self, MapAccess, Visitor};
15use serde::ser::SerializeStruct;
16use serde::{Deserialize, Deserializer, Serialize, Serializer};
17use serde_json::Value;
18use std::collections::{BTreeMap, HashMap};
19use std::path::PathBuf;
20use tracing::debug;
21
/// Kind of bins stored in a [`PsiFeatureDriftProfile`] (see its `bin_type` field).
///
/// Exposed to Python with equality support (`#[pyclass(eq)]`).
#[pyclass(eq)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum BinType {
    /// Bins built for a numeric feature.
    // NOTE(review): presumably these are the bins delimited by `lower_limit` /
    // `upper_limit` — confirm against the binning code.
    Numeric,
    /// Bins built for a categorical feature.
    Category,
}
28
/// Configuration carried by a PSI drift profile.
///
/// Exposed to Python as a class. Every field is readable from Python and all of them
/// except `feature_map` are also writable.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftConfig {
    /// Space identifier; the Python constructor defaults it to `MISSING`.
    #[pyo3(get, set)]
    pub space: String,

    /// Name identifier; the Python constructor defaults it to `MISSING`.
    #[pyo3(get, set)]
    pub name: String,

    /// Version string; the Python constructor defaults it to `DEFAULT_VERSION`.
    #[pyo3(get, set)]
    pub version: String,

    /// Alert settings. Its `dispatch_config` is used by `get_drift_args` and its
    /// `schedule` by `get_base_args`.
    #[pyo3(get, set)]
    pub alert_config: PsiAlertConfig,

    /// Feature map. Read-only from Python; Rust code replaces it through
    /// `update_feature_map`. Falls back to `FeatureMap::default()` when the key is
    /// missing from deserialized input.
    #[pyo3(get)]
    #[serde(default)]
    pub feature_map: FeatureMap,

    /// Drift type. Falls back to `DriftType::Psi` (via `default_drift_type`) when the
    /// key is missing from deserialized input.
    #[pyo3(get, set)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,

    /// Optional list of feature names.
    // NOTE(review): presumably the features that must be treated as categorical when
    // bins are computed — confirm against the profiling code.
    #[pyo3(get, set)]
    pub categorical_features: Option<Vec<String>>,
}
55
/// Serde default for `PsiDriftConfig::drift_type`: always `DriftType::Psi`.
fn default_drift_type() -> DriftType {
    DriftType::Psi
}
59
impl PsiDriftConfig {
    /// Replaces the stored feature map with `feature_map`.
    ///
    /// Lives outside the `#[pymethods]` block, so it is only callable from Rust; Python
    /// sees `feature_map` through a getter only.
    pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
        self.feature_map = feature_map;
    }
}
65
66#[pymethods]
67#[allow(clippy::too_many_arguments)]
68impl PsiDriftConfig {
69 #[new]
70 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None))]
71 pub fn new(
72 space: &str,
73 name: &str,
74 version: &str,
75 alert_config: PsiAlertConfig,
76 config_path: Option<PathBuf>,
77 categorical_features: Option<Vec<String>>,
78 ) -> Result<Self, ProfileError> {
79 if let Some(config_path) = config_path {
80 let config = PsiDriftConfig::load_from_json_file(config_path);
81 return config;
82 }
83
84 if name == MISSING || space == MISSING {
85 debug!("Name and space were not provided. Defaulting to __missing__");
86 }
87
88 Ok(Self {
89 name: name.to_string(),
90 space: space.to_string(),
91 version: version.to_string(),
92 alert_config,
93 categorical_features,
94 feature_map: FeatureMap::default(),
95 drift_type: DriftType::Psi,
96 })
97 }
98
99 #[staticmethod]
100 pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
101 let file = std::fs::read_to_string(&path)?;
104
105 Ok(serde_json::from_str(&file)?)
106 }
107
108 pub fn __str__(&self) -> String {
109 ProfileFuncs::__str__(self)
111 }
112
113 pub fn model_dump_json(&self) -> String {
114 ProfileFuncs::__json__(self)
116 }
117
118 #[allow(clippy::too_many_arguments)]
119 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
120 pub fn update_config_args(
121 &mut self,
122 space: Option<String>,
123 name: Option<String>,
124 version: Option<String>,
125 alert_config: Option<PsiAlertConfig>,
126 ) -> Result<(), TypeError> {
127 if name.is_some() {
128 self.name = name.ok_or(TypeError::MissingNameError)?;
129 }
130
131 if space.is_some() {
132 self.space = space.ok_or(TypeError::MissingSpaceError)?;
133 }
134
135 if version.is_some() {
136 self.version = version.ok_or(TypeError::MissingVersionError)?;
137 }
138
139 if alert_config.is_some() {
140 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
141 }
142
143 Ok(())
144 }
145}
146
147impl Default for PsiDriftConfig {
148 fn default() -> Self {
149 PsiDriftConfig {
150 name: "__missing__".to_string(),
151 space: "__missing__".to_string(),
152 version: DEFAULT_VERSION.to_string(),
153 feature_map: FeatureMap::default(),
154 alert_config: PsiAlertConfig::default(),
155 drift_type: DriftType::Psi,
156 categorical_features: None,
157 }
158 }
159}
160impl DispatchDriftConfig for PsiDriftConfig {
163 fn get_drift_args(&self) -> DriftArgs {
164 DriftArgs {
165 name: self.name.clone(),
166 space: self.space.clone(),
167 version: self.version.clone(),
168 dispatch_config: self.alert_config.dispatch_config.clone(),
169 }
170 }
171}
172
/// One bin of a feature, as stored in `PsiFeatureDriftProfile::bins`.
///
/// Serialization is hand-written (see the `Serialize` / `Deserialize` impls) because
/// the limits may be infinite, which plain JSON numbers cannot express.
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
pub struct Bin {
    /// Identifier of the bin.
    #[pyo3(get)]
    pub id: usize,

    /// Lower limit of the bin, if any. May be infinite.
    // NOTE(review): presumably `None` for categorical bins — confirm with the code
    // that builds the bins.
    #[pyo3(get)]
    pub lower_limit: Option<f64>,

    /// Upper limit of the bin, if any. May be infinite.
    #[pyo3(get)]
    pub upper_limit: Option<f64>,

    /// Proportion associated with the bin.
    // NOTE(review): presumably the share of reference observations falling in the
    // bin — confirm with the code that builds the bins.
    #[pyo3(get)]
    pub proportion: f64,
}
188
189impl Serialize for Bin {
190 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
191 where
192 S: Serializer,
193 {
194 let mut state = serializer.serialize_struct("Bin", 4)?;
195 state.serialize_field("id", &self.id)?;
196
197 state.serialize_field(
198 "lower_limit",
199 &self.lower_limit.map(|v| {
200 if v.is_infinite() {
201 serde_json::Value::String(if v.is_sign_positive() {
202 "inf".to_string()
203 } else {
204 "-inf".to_string()
205 })
206 } else {
207 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
208 }
209 }),
210 )?;
211 state.serialize_field(
212 "upper_limit",
213 &self.upper_limit.map(|v| {
214 if v.is_infinite() {
215 serde_json::Value::String(if v.is_sign_positive() {
216 "inf".to_string()
217 } else {
218 "-inf".to_string()
219 })
220 } else {
221 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
222 }
223 }),
224 )?;
225 state.serialize_field("proportion", &self.proportion)?;
226 state.end()
227 }
228}
229
230impl<'de> Deserialize<'de> for Bin {
231 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
232 where
233 D: Deserializer<'de>,
234 {
235 #[derive(Deserialize)]
236 #[serde(untagged)]
237 enum NumberOrString {
238 Number(f64),
239 String(String),
240 }
241
242 #[derive(Deserialize)]
243 #[serde(field_identifier, rename_all = "snake_case")]
244 enum Field {
245 Id,
246 LowerLimit,
247 UpperLimit,
248 Proportion,
249 }
250
251 struct BinVisitor;
252
253 impl<'de> Visitor<'de> for BinVisitor {
254 type Value = Bin;
255
256 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
257 formatter.write_str("struct Bin")
258 }
259
260 fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
261 where
262 V: MapAccess<'de>,
263 {
264 let mut id = None;
265 let mut lower_limit = None;
266 let mut upper_limit = None;
267 let mut proportion = None;
268
269 while let Some(key) = map.next_key()? {
270 match key {
271 Field::Id => {
272 id = Some(map.next_value()?);
273 }
274 Field::LowerLimit => {
275 let val: Option<NumberOrString> = map.next_value()?;
276 lower_limit = Some(val.map(|v| match v {
277 NumberOrString::String(s) => match s.as_str() {
278 "inf" => f64::INFINITY,
279 "-inf" => f64::NEG_INFINITY,
280 _ => s.parse().unwrap(),
281 },
282 NumberOrString::Number(n) => n,
283 }));
284 }
285 Field::UpperLimit => {
286 let val: Option<NumberOrString> = map.next_value()?;
287 upper_limit = Some(val.map(|v| match v {
288 NumberOrString::String(s) => match s.as_str() {
289 "inf" => f64::INFINITY,
290 "-inf" => f64::NEG_INFINITY,
291 _ => s.parse().unwrap(),
292 },
293 NumberOrString::Number(n) => n,
294 }));
295 }
296 Field::Proportion => {
297 proportion = Some(map.next_value()?);
298 }
299 }
300 }
301
302 Ok(Bin {
303 id: id.ok_or_else(|| de::Error::missing_field("id"))?,
304 lower_limit: lower_limit
305 .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
306 upper_limit: upper_limit
307 .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
308 proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
309 })
310 }
311 }
312
313 const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
314 deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
315 }
316}
317
/// PSI profile of a single feature: its bins plus bookkeeping data.
///
/// All fields are read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiFeatureDriftProfile {
    /// Identifier of the feature profile.
    // NOTE(review): presumably the feature name — confirm with the profile builder.
    #[pyo3(get)]
    pub id: String,

    /// Bins recorded for the feature.
    #[pyo3(get)]
    pub bins: Vec<Bin>,

    /// UTC timestamp stored with the profile.
    // NOTE(review): presumably the creation time — confirm with the profile builder.
    #[pyo3(get)]
    pub timestamp: chrono::DateTime<Utc>,

    /// Kind of bins held in `bins`.
    #[pyo3(get)]
    pub bin_type: BinType,
}
333
/// Complete PSI drift profile: per-feature profiles, the configuration and the
/// library version string recorded with it.
///
/// All fields are read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftProfile {
    /// Feature profiles, keyed by a string.
    // NOTE(review): presumably keyed by feature name — confirm with the profile builder.
    #[pyo3(get)]
    pub features: HashMap<String, PsiFeatureDriftProfile>,

    /// Configuration of the profile.
    #[pyo3(get)]
    pub config: PsiDriftConfig,

    /// Version string; `PsiDriftProfile::new` falls back to this crate's
    /// `CARGO_PKG_VERSION` when none is supplied.
    #[pyo3(get)]
    pub scouter_version: String,
}
346
347impl PsiDriftProfile {
348 pub fn new(
349 features: HashMap<String, PsiFeatureDriftProfile>,
350 config: PsiDriftConfig,
351 scouter_version: Option<String>,
352 ) -> Self {
353 let scouter_version = scouter_version.unwrap_or(env!("CARGO_PKG_VERSION").to_string());
354 Self {
355 features,
356 config,
357 scouter_version,
358 }
359 }
360}
361
362#[pymethods]
363impl PsiDriftProfile {
364 pub fn __str__(&self) -> String {
365 ProfileFuncs::__str__(self)
367 }
368
369 pub fn model_dump_json(&self) -> String {
370 ProfileFuncs::__json__(self)
372 }
373 #[allow(clippy::useless_conversion)]
375 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
376 let json_str = serde_json::to_string(&self)?;
377
378 let json_value: Value = serde_json::from_str(&json_str)?;
379
380 let dict = PyDict::new(py);
382
383 json_to_pyobject(py, &json_value, &dict)?;
385
386 Ok(dict.into())
388 }
389
390 #[staticmethod]
391 pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
392 let file = std::fs::read_to_string(&path)?;
393
394 Ok(serde_json::from_str(&file)?)
395 }
396
397 #[staticmethod]
398 pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
399 let json_value = pyobject_to_json(data).unwrap();
400
401 let string = serde_json::to_string(&json_value).unwrap();
402 serde_json::from_str(&string).expect("Failed to load drift profile")
403 }
404
405 #[staticmethod]
406 pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
407 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
409 }
410
411 #[pyo3(signature = (path=None))]
413 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
414 Ok(ProfileFuncs::save_to_json(
415 self,
416 path,
417 FileName::PsiDriftProfile.to_str(),
418 )?)
419 }
420
421 #[allow(clippy::too_many_arguments)]
422 #[pyo3(signature = (space=None, name=None, version=None, alert_config=None))]
423 pub fn update_config_args(
424 &mut self,
425 space: Option<String>,
426 name: Option<String>,
427 version: Option<String>,
428 alert_config: Option<PsiAlertConfig>,
429 ) -> Result<(), TypeError> {
430 self.config
431 .update_config_args(space, name, version, alert_config)
432 }
433
434 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
436 Ok(ProfileRequest {
437 space: self.config.space.clone(),
438 profile: self.model_dump_json(),
439 drift_type: self.config.drift_type.clone(),
440 })
441 }
442}
443
/// Map from a string key to a float value, together with the name, space and version
/// it belongs to.
///
/// All fields are read-only from Python.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PsiDriftMap {
    /// One float per key; starts empty in `PsiDriftMap::new`.
    // NOTE(review): presumably feature name -> computed PSI value — confirm with the
    // drift computation code.
    #[pyo3(get)]
    pub features: HashMap<String, f64>,

    /// Name identifier.
    #[pyo3(get)]
    pub name: String,

    /// Space identifier.
    #[pyo3(get)]
    pub space: String,

    /// Version string.
    #[pyo3(get)]
    pub version: String,
}
459
460impl PsiDriftMap {
461 pub fn new(space: String, name: String, version: String) -> Self {
462 Self {
463 features: HashMap::new(),
464 name,
465 space,
466 version,
467 }
468 }
469}
470
471#[pymethods]
472#[allow(clippy::new_without_default)]
473impl PsiDriftMap {
474 pub fn __str__(&self) -> String {
475 ProfileFuncs::__str__(self)
477 }
478
479 pub fn model_dump_json(&self) -> String {
480 ProfileFuncs::__json__(self)
482 }
483
484 #[staticmethod]
485 pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
486 Ok(serde_json::from_str(&json_string)?)
488 }
489
490 #[pyo3(signature = (path=None))]
491 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
492 Ok(ProfileFuncs::save_to_json(
493 self,
494 path,
495 FileName::PsiDriftMap.to_str(),
496 )?)
497 }
498}
499
500impl ProfileBaseArgs for PsiDriftProfile {
502 fn get_base_args(&self) -> ProfileArgs {
504 ProfileArgs {
505 name: self.config.name.clone(),
506 space: self.config.space.clone(),
507 version: self.config.version.clone(),
508 schedule: self.config.alert_config.schedule.clone(),
509 scouter_version: self.scouter_version.clone(),
510 drift_type: self.config.drift_type.clone(),
511 }
512 }
513
514 fn to_value(&self) -> Value {
516 serde_json::to_value(self).unwrap()
517 }
518}
519
/// Distribution of one feature: a sample size plus one float per bin key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionData {
    /// Number of samples behind the distribution.
    pub sample_size: u64,
    /// One float per integer key, kept in key order by the `BTreeMap`.
    // NOTE(review): presumably bin id -> observed proportion — confirm with the code
    // that fills this map.
    pub bins: BTreeMap<usize, f64>,
}
525
/// Collection of [`DistributionData`] values keyed by a string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureDistributions {
    /// One distribution per key, kept in key order by the `BTreeMap`.
    // NOTE(review): presumably keyed by feature name — confirm with the producer.
    pub distributions: BTreeMap<String, DistributionData>,
}
530
impl FeatureDistributions {
    /// Returns `true` when no distribution is stored.
    pub fn is_empty(&self) -> bool {
        self.distributions.is_empty()
    }
}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built from the default arguments carries the placeholder identifiers,
    /// and `update_config_args` changes only the field it is given.
    #[test]
    fn test_drift_config() {
        let built = PsiDriftConfig::new(
            MISSING,
            MISSING,
            DEFAULT_VERSION,
            PsiAlertConfig::default(),
            None,
            None,
        );
        let mut config = built.unwrap();

        assert_eq!(config.name, "__missing__");
        assert_eq!(config.space, "__missing__");
        assert_eq!(config.version, "0.1.0");
        assert_eq!(config.alert_config, PsiAlertConfig::default());

        // Only the name is supplied; every other argument stays `None`.
        let new_name = Some("test".to_string());
        config
            .update_config_args(None, new_name, None, None)
            .unwrap();

        assert_eq!(config.name, "test");
    }
}