1#![allow(clippy::useless_conversion)]
2use crate::binning::equal_width::EqualWidthBinning;
3use crate::binning::quantile::QuantileBinning;
4use crate::binning::strategy::BinningStrategy;
5use crate::error::{ProfileError, TypeError};
6use crate::psi::alert::PsiAlertConfig;
7use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version, ConfigExt};
8use crate::ProfileRequest;
9use crate::VersionRequest;
10use crate::{
11 DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
12 PyHelperFuncs, DEFAULT_VERSION, MISSING,
13};
14use chrono::Utc;
15use core::fmt::Debug;
16use potato_head::create_uuid7;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use scouter_semver::VersionType;
20use serde::de::{self, MapAccess, Visitor};
21use serde::ser::SerializeStruct;
22use serde::{Deserialize, Deserializer, Serialize, Serializer};
23use serde_json::Value;
24use std::collections::{BTreeMap, HashMap};
25use std::path::PathBuf;
26use tracing::debug;
27
/// Kind of binning recorded on a feature's PSI profile (see `PsiFeatureDriftProfile::bin_type`).
#[pyclass(eq)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum BinType {
    /// Bins over a numeric range (NOTE(review): inferred from the name — confirm).
    Numeric,
    /// Bins over category values (NOTE(review): inferred from the name — confirm).
    Category,
}
34
/// Configuration for a PSI drift profile, exposed to Python as a class.
///
/// `binning_strategy` carries no `#[pyo3]` field attribute. Python reaches it through the
/// hand-written `binning_strategy` getter/setter in the `#[pymethods]` block. That getter/setter
/// converts to and from the Python `QuantileBinning` / `EqualWidthBinning` objects.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftConfig {
    /// Space the profile belongs to; readable and writable from Python.
    #[pyo3(get, set)]
    pub space: String,

    /// Profile name; readable and writable from Python.
    #[pyo3(get, set)]
    pub name: String,

    /// Profile version string; readable and writable from Python.
    #[pyo3(get, set)]
    pub version: String,

    /// Unique id. It is read-only from Python as an attribute, but `update_config_args` can
    /// overwrite it. `new` assigns a fresh value from `create_uuid7()`.
    #[pyo3(get)]
    pub uid: String,

    /// Alert settings; `alert_config.dispatch_config` feeds `get_drift_args`.
    #[pyo3(get, set)]
    pub alert_config: PsiAlertConfig,

    /// Feature map; falls back to `FeatureMap::default()` when absent from serialized input
    /// and is replaced wholesale by `update_feature_map`.
    #[pyo3(get)]
    #[serde(default)]
    pub feature_map: FeatureMap,

    /// Drift type; falls back to `DriftType::Psi` when absent from serialized input.
    #[pyo3(get)]
    #[serde(default = "default_drift_type")]
    pub drift_type: DriftType,

    /// Optional list of feature names, presumably those to treat as categorical — confirm.
    #[pyo3(get, set)]
    pub categorical_features: Option<Vec<String>>,

    /// Strategy used to compute bin edges (see the manual getter/setter for Python access).
    pub binning_strategy: BinningStrategy,
}
66
67impl ConfigExt for PsiDriftConfig {
68 fn space(&self) -> &str {
69 &self.space
70 }
71
72 fn name(&self) -> &str {
73 &self.name
74 }
75
76 fn version(&self) -> &str {
77 &self.version
78 }
79}
80
81fn default_drift_type() -> DriftType {
82 DriftType::Psi
83}
84
85impl PsiDriftConfig {
86 pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
87 self.feature_map = feature_map;
88 }
89}
90
91#[pymethods]
92#[allow(clippy::too_many_arguments)]
93impl PsiDriftConfig {
94 #[new]
95 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None, binning_strategy=None))]
96 pub fn new(
97 space: &str,
98 name: &str,
99 version: &str,
100 alert_config: PsiAlertConfig,
101 config_path: Option<PathBuf>,
102 categorical_features: Option<Vec<String>>,
103 binning_strategy: Option<&Bound<'_, PyAny>>,
104 ) -> Result<Self, ProfileError> {
105 if let Some(config_path) = config_path {
106 let config = PsiDriftConfig::load_from_json_file(config_path);
107 return config;
108 }
109
110 let binning_strategy = match binning_strategy {
111 None => BinningStrategy::default(),
112 Some(strategy) => {
113 if strategy.is_instance_of::<QuantileBinning>() {
114 BinningStrategy::QuantileBinning(strategy.extract()?)
115 } else if strategy.is_instance_of::<EqualWidthBinning>() {
116 BinningStrategy::EqualWidthBinning(strategy.extract()?)
117 } else {
118 return Err(ProfileError::InvalidBinningStrategyError);
119 }
120 }
121 };
122
123 if name == MISSING || space == MISSING {
124 debug!("Name and space were not provided. Defaulting to __missing__");
125 }
126
127 Ok(Self {
128 name: name.to_string(),
129 space: space.to_string(),
130 version: version.to_string(),
131 uid: create_uuid7(),
132 alert_config,
133 categorical_features,
134 feature_map: FeatureMap::default(),
135 drift_type: DriftType::Psi,
136 binning_strategy,
137 })
138 }
139
140 #[staticmethod]
141 pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
142 let file = std::fs::read_to_string(&path)?;
145
146 Ok(serde_json::from_str(&file)?)
147 }
148
149 pub fn __str__(&self) -> String {
150 PyHelperFuncs::__str__(self)
152 }
153
154 pub fn model_dump_json(&self) -> String {
155 PyHelperFuncs::__json__(self)
157 }
158
159 #[allow(clippy::too_many_arguments)]
160 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None, categorical_features=None, binning_strategy=None))]
161 pub fn update_config_args(
162 &mut self,
163 space: Option<String>,
164 name: Option<String>,
165 version: Option<String>,
166 uid: Option<String>,
167 alert_config: Option<PsiAlertConfig>,
168 categorical_features: Option<Vec<String>>,
169 binning_strategy: Option<&Bound<'_, PyAny>>,
170 ) -> Result<(), TypeError> {
171 if name.is_some() {
172 self.name = name.ok_or(TypeError::MissingNameError)?;
173 }
174
175 if space.is_some() {
176 self.space = space.ok_or(TypeError::MissingSpaceError)?;
177 }
178
179 if version.is_some() {
180 self.version = version.ok_or(TypeError::MissingVersionError)?;
181 }
182
183 if alert_config.is_some() {
184 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
185 }
186
187 if uid.is_some() {
188 self.uid = uid.ok_or(TypeError::MissingUidError)?;
189 }
190
191 if categorical_features.is_some() {
192 self.categorical_features = categorical_features;
193 }
194
195 if let Some(binning_strategy) = binning_strategy {
196 if binning_strategy.is_instance_of::<QuantileBinning>() {
197 self.binning_strategy =
198 BinningStrategy::QuantileBinning(binning_strategy.extract()?);
199 } else if binning_strategy.is_instance_of::<EqualWidthBinning>() {
200 self.binning_strategy =
201 BinningStrategy::EqualWidthBinning(binning_strategy.extract()?);
202 } else {
203 return Err(TypeError::InvalidBinningStrategyError);
204 }
205 }
206
207 Ok(())
208 }
209
210 #[getter]
211 pub fn binning_strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
212 self.binning_strategy.strategy(py)
213 }
214
215 #[setter]
216 pub fn set_binning_strategy(&mut self, strategy: &Bound<'_, PyAny>) -> PyResult<()> {
217 if strategy.is_instance_of::<QuantileBinning>() {
218 self.binning_strategy = BinningStrategy::QuantileBinning(strategy.extract()?);
219 } else if strategy.is_instance_of::<EqualWidthBinning>() {
220 self.binning_strategy = BinningStrategy::EqualWidthBinning(strategy.extract()?);
221 } else {
222 return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
223 "Invalid binning strategy type",
224 ));
225 }
226 Ok(())
227 }
228}
229
230impl Default for PsiDriftConfig {
231 fn default() -> Self {
232 PsiDriftConfig {
233 name: "__missing__".to_string(),
234 space: "__missing__".to_string(),
235 version: DEFAULT_VERSION.to_string(),
236 uid: MISSING.to_string(),
237 feature_map: FeatureMap::default(),
238 alert_config: PsiAlertConfig::default(),
239 drift_type: DriftType::Psi,
240 categorical_features: None,
241 binning_strategy: BinningStrategy::default(),
242 }
243 }
244}
245impl DispatchDriftConfig for PsiDriftConfig {
248 fn get_drift_args(&self) -> DriftArgs {
249 DriftArgs {
250 name: self.name.clone(),
251 space: self.space.clone(),
252 version: self.version.clone(),
253 dispatch_config: self.alert_config.dispatch_config.clone(),
254 }
255 }
256}
257
/// One bin of a PSI feature profile, exposed to Python as a read-only class.
///
/// `Serialize`/`Deserialize` are implemented by hand for this type. Infinite limits are
/// written as the strings `"inf"`/`"-inf"` because JSON numbers cannot represent them.
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
pub struct Bin {
    /// Bin identifier.
    #[pyo3(get)]
    pub id: i32,

    /// Lower boundary. It may be `None` (serialized as `null`) or infinite.
    #[pyo3(get)]
    pub lower_limit: Option<f64>,

    /// Upper boundary. It may be `None` (serialized as `null`) or infinite.
    #[pyo3(get)]
    pub upper_limit: Option<f64>,

    /// Share of observations in this bin (presumably in `[0, 1]` — confirm with the profiler).
    #[pyo3(get)]
    pub proportion: f64,
}
273
274impl Serialize for Bin {
275 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
276 where
277 S: Serializer,
278 {
279 let mut state = serializer.serialize_struct("Bin", 4)?;
280 state.serialize_field("id", &self.id)?;
281
282 state.serialize_field(
283 "lower_limit",
284 &self.lower_limit.map(|v| {
285 if v.is_infinite() {
286 serde_json::Value::String(if v.is_sign_positive() {
287 "inf".to_string()
288 } else {
289 "-inf".to_string()
290 })
291 } else {
292 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
293 }
294 }),
295 )?;
296 state.serialize_field(
297 "upper_limit",
298 &self.upper_limit.map(|v| {
299 if v.is_infinite() {
300 serde_json::Value::String(if v.is_sign_positive() {
301 "inf".to_string()
302 } else {
303 "-inf".to_string()
304 })
305 } else {
306 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
307 }
308 }),
309 )?;
310 state.serialize_field("proportion", &self.proportion)?;
311 state.end()
312 }
313}
314
315impl<'de> Deserialize<'de> for Bin {
316 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
317 where
318 D: Deserializer<'de>,
319 {
320 #[derive(Deserialize)]
321 #[serde(untagged)]
322 enum NumberOrString {
323 Number(f64),
324 String(String),
325 }
326
327 #[derive(Deserialize)]
328 #[serde(field_identifier, rename_all = "snake_case")]
329 enum Field {
330 Id,
331 LowerLimit,
332 UpperLimit,
333 Proportion,
334 }
335
336 struct BinVisitor;
337
338 impl<'de> Visitor<'de> for BinVisitor {
339 type Value = Bin;
340
341 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
342 formatter.write_str("struct Bin")
343 }
344
345 fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
346 where
347 V: MapAccess<'de>,
348 {
349 let mut id = None;
350 let mut lower_limit = None;
351 let mut upper_limit = None;
352 let mut proportion = None;
353
354 while let Some(key) = map.next_key()? {
355 match key {
356 Field::Id => {
357 id = Some(map.next_value()?);
358 }
359 Field::LowerLimit => {
360 let val: Option<NumberOrString> = map.next_value()?;
361 lower_limit = Some(val.map(|v| match v {
362 NumberOrString::String(s) => match s.as_str() {
363 "inf" => f64::INFINITY,
364 "-inf" => f64::NEG_INFINITY,
365 _ => s.parse().unwrap(),
366 },
367 NumberOrString::Number(n) => n,
368 }));
369 }
370 Field::UpperLimit => {
371 let val: Option<NumberOrString> = map.next_value()?;
372 upper_limit = Some(val.map(|v| match v {
373 NumberOrString::String(s) => match s.as_str() {
374 "inf" => f64::INFINITY,
375 "-inf" => f64::NEG_INFINITY,
376 _ => s.parse().unwrap(),
377 },
378 NumberOrString::Number(n) => n,
379 }));
380 }
381 Field::Proportion => {
382 proportion = Some(map.next_value()?);
383 }
384 }
385 }
386
387 Ok(Bin {
388 id: id.ok_or_else(|| de::Error::missing_field("id"))?,
389 lower_limit: lower_limit
390 .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
391 upper_limit: upper_limit
392 .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
393 proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
394 })
395 }
396 }
397
398 const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
399 deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
400 }
401}
402
/// PSI profile for a single feature, exposed to Python as a read-only class.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiFeatureDriftProfile {
    /// Identifier of the profiled feature (presumably the feature name — confirm).
    #[pyo3(get)]
    pub id: String,

    /// Bins making up the feature's reference distribution.
    #[pyo3(get)]
    pub bins: Vec<Bin>,

    /// UTC timestamp attached to this profile.
    #[pyo3(get)]
    pub timestamp: chrono::DateTime<Utc>,

    /// Whether the bins are numeric or categorical.
    #[pyo3(get)]
    pub bin_type: BinType,
}
418
/// Complete PSI drift profile: per-feature profiles plus the config they were built with.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftProfile {
    /// Per-feature profiles, presumably keyed by feature name — confirm.
    #[pyo3(get)]
    pub features: HashMap<String, PsiFeatureDriftProfile>,

    /// Configuration (space/name/version/uid, alerting, binning) for this profile.
    #[pyo3(get)]
    pub config: PsiDriftConfig,

    /// Library version stamped by `PsiDriftProfile::new` via `scouter_version()`.
    #[pyo3(get)]
    pub scouter_version: String,
}
431
432impl PsiDriftProfile {
433 pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
434 Self {
435 features,
436 config,
437 scouter_version: scouter_version(),
438 }
439 }
440}
441
442#[pymethods]
443impl PsiDriftProfile {
444 pub fn __str__(&self) -> String {
445 PyHelperFuncs::__str__(self)
447 }
448
449 #[getter]
450 pub fn uid(&self) -> String {
451 self.config.uid.clone()
452 }
453
454 #[setter]
455 pub fn set_uid(&mut self, uid: String) {
456 self.config.uid = uid;
457 }
458
459 pub fn model_dump_json(&self) -> String {
460 PyHelperFuncs::__json__(self)
462 }
463 #[allow(clippy::useless_conversion)]
465 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
466 let json_str = serde_json::to_string(&self)?;
467
468 let json_value: Value = serde_json::from_str(&json_str)?;
469
470 let dict = PyDict::new(py);
472
473 json_to_pyobject(py, &json_value, &dict)?;
475
476 Ok(dict.into())
478 }
479
480 #[staticmethod]
481 pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
482 let file = std::fs::read_to_string(&path)?;
483
484 Ok(serde_json::from_str(&file)?)
485 }
486
487 #[staticmethod]
488 pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
489 let json_value = pyobject_to_json(data).unwrap();
490
491 let string = serde_json::to_string(&json_value).unwrap();
492 serde_json::from_str(&string).expect("Failed to load drift profile")
493 }
494
495 #[staticmethod]
496 pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
497 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
499 }
500
501 #[pyo3(signature = (path=None))]
503 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
504 Ok(PyHelperFuncs::save_to_json(
505 self,
506 path,
507 FileName::PsiDriftProfile.to_str(),
508 )?)
509 }
510
511 #[allow(clippy::too_many_arguments)]
512 #[pyo3(signature = (space=None, name=None, version=None, uid=None,alert_config=None, categorical_features=None, binning_strategy=None))]
513 pub fn update_config_args(
514 &mut self,
515 space: Option<String>,
516 name: Option<String>,
517 version: Option<String>,
518 uid: Option<String>,
519 alert_config: Option<PsiAlertConfig>,
520 categorical_features: Option<Vec<String>>,
521 binning_strategy: Option<&Bound<'_, PyAny>>,
522 ) -> Result<(), TypeError> {
523 self.config.update_config_args(
524 space,
525 name,
526 version,
527 uid,
528 alert_config,
529 categorical_features,
530 binning_strategy,
531 )
532 }
533
534 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
536 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
537 None
538 } else {
539 Some(self.config.version.clone())
540 };
541
542 Ok(ProfileRequest {
543 space: self.config.space.clone(),
544 profile: self.model_dump_json(),
545 drift_type: self.config.drift_type.clone(),
546 version_request: Some(VersionRequest {
547 version,
548 version_type: VersionType::Minor,
549 pre_tag: None,
550 build_tag: None,
551 }),
552 active: false,
553 deactivate_others: false,
554 })
555 }
556}
557
/// Per-feature PSI values for one space/name/version, exposed to Python as a read-only class.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PsiDriftMap {
    /// Feature -> value map, presumably feature name -> computed PSI — confirm with callers.
    #[pyo3(get)]
    pub features: HashMap<String, f64>,

    /// Profile name.
    #[pyo3(get)]
    pub name: String,

    /// Profile space.
    #[pyo3(get)]
    pub space: String,

    /// Profile version.
    #[pyo3(get)]
    pub version: String,
}
573
574impl PsiDriftMap {
575 pub fn new(space: String, name: String, version: String) -> Self {
576 Self {
577 features: HashMap::new(),
578 name,
579 space,
580 version,
581 }
582 }
583}
584
585#[pymethods]
586#[allow(clippy::new_without_default)]
587impl PsiDriftMap {
588 pub fn __str__(&self) -> String {
589 PyHelperFuncs::__str__(self)
591 }
592
593 pub fn model_dump_json(&self) -> String {
594 PyHelperFuncs::__json__(self)
596 }
597
598 #[staticmethod]
599 pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
600 Ok(serde_json::from_str(&json_string)?)
602 }
603
604 #[pyo3(signature = (path=None))]
605 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
606 Ok(PyHelperFuncs::save_to_json(
607 self,
608 path,
609 FileName::PsiDriftMap.to_str(),
610 )?)
611 }
612}
613
614impl ProfileBaseArgs for PsiDriftProfile {
616 type Config = PsiDriftConfig;
617
618 fn config(&self) -> &Self::Config {
619 &self.config
620 }
621 fn get_base_args(&self) -> ProfileArgs {
623 ProfileArgs {
624 name: self.config.name.clone(),
625 space: self.config.space.clone(),
626 version: Some(self.config.version.clone()),
627 schedule: self.config.alert_config.schedule.clone(),
628 scouter_version: self.scouter_version.clone(),
629 drift_type: self.config.drift_type.clone(),
630 }
631 }
632
633 fn to_value(&self) -> Value {
635 serde_json::to_value(self).unwrap()
636 }
637}
638
/// Observed distribution for one feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionData {
    /// Number of observations behind the distribution.
    pub sample_size: u64,
    /// Integer bin key -> value (presumably bin id -> proportion — confirm against writers).
    pub bins: BTreeMap<i32, f64>,
}
644
/// Distributions for a set of features, keyed by feature name (as built by `from_rows`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureDistributions {
    /// Feature name -> distribution, kept in sorted key order by the `BTreeMap`.
    pub distributions: BTreeMap<String, DistributionData>,
}
649
650impl FeatureDistributions {
651 pub fn is_empty(&self) -> bool {
652 self.distributions.is_empty()
653 }
654}
655
/// One feature's distribution as read from a database row (see the `sqlx::FromRow` impl
/// compiled under the `server` feature).
#[derive(Debug)]
pub struct FeatureDistributionRow {
    /// Feature name, read from the `feature` column.
    pub feature: String,
    /// Sample size and bins, read from the `sample_size` and `bins` columns.
    pub distribution: DistributionData,
}
661
#[cfg(feature = "server")]
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for FeatureDistributionRow {
    /// Decode a row with `feature` (text), `sample_size` (i64) and `bins` (JSON) columns.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// * a column is missing or has an incompatible type;
    /// * `bins` does not deserialize into a map of `i32` keys to `f64` values;
    /// * `sample_size` is negative.
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;

        let feature: String = row.try_get("feature")?;
        let sample_size: i64 = row.try_get("sample_size")?;
        let bins_json: serde_json::Value = row.try_get("bins")?;
        let bins: BTreeMap<i32, f64> =
            serde_json::from_value(bins_json).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        // A plain `as u64` cast would silently wrap a negative count into a huge value.
        let sample_size =
            u64::try_from(sample_size).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        Ok(FeatureDistributionRow {
            feature,
            distribution: DistributionData { sample_size, bins },
        })
    }
}
682
683impl FeatureDistributions {
684 pub fn from_rows(rows: Vec<FeatureDistributionRow>) -> Self {
686 let distributions = rows
687 .into_iter()
688 .map(|row| (row.feature, row.distribution))
689 .collect();
690
691 FeatureDistributions { distributions }
692 }
693}