1#![allow(clippy::useless_conversion)]
2use crate::binning::equal_width::EqualWidthBinning;
3use crate::binning::quantile::QuantileBinning;
4use crate::binning::strategy::BinningStrategy;
5use crate::error::{ProfileError, TypeError};
6use crate::psi::alert::PsiAlertConfig;
7use crate::traits::ConfigExt;
8use crate::util::{json_to_pyobject, pyobject_to_json, scouter_version};
9use crate::ProfileRequest;
10use crate::VersionRequest;
11use crate::{
12 DispatchDriftConfig, DriftArgs, DriftType, FeatureMap, FileName, ProfileArgs, ProfileBaseArgs,
13 PyHelperFuncs, DEFAULT_VERSION, MISSING,
14};
15use chrono::Utc;
16use core::fmt::Debug;
17use potato_head::create_uuid7;
18use pyo3::prelude::*;
19use pyo3::types::PyDict;
20use scouter_semver::VersionType;
21use serde::de::{self, MapAccess, Visitor};
22use serde::ser::SerializeStruct;
23use serde::{Deserialize, Deserializer, Serialize, Serializer};
24use serde_json::Value;
25use std::collections::{BTreeMap, HashMap};
26use std::path::PathBuf;
27use tracing::debug;
28
/// Kind of bins stored in a feature's PSI profile.
#[pyclass(eq)]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum BinType {
    /// Bins for a numeric feature (presumably range-based, see `Bin::lower_limit` / `upper_limit`).
    Numeric,
    /// Bins for a categorical feature.
    Category,
}
35
36#[pyclass]
37#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
38pub struct PsiDriftConfig {
39 #[pyo3(get, set)]
40 pub space: String,
41
42 #[pyo3(get, set)]
43 pub name: String,
44
45 #[pyo3(get, set)]
46 pub version: String,
47
48 #[pyo3(get)]
49 pub uid: String,
50
51 #[pyo3(get, set)]
52 pub alert_config: PsiAlertConfig,
53
54 #[pyo3(get)]
55 #[serde(default)]
56 pub feature_map: FeatureMap,
57
58 #[pyo3(get)]
59 #[serde(default = "default_drift_type")]
60 pub drift_type: DriftType,
61
62 #[pyo3(get, set)]
63 pub categorical_features: Option<Vec<String>>,
64
65 pub binning_strategy: BinningStrategy,
66}
67
68impl ConfigExt for PsiDriftConfig {
69 fn space(&self) -> &str {
70 &self.space
71 }
72
73 fn name(&self) -> &str {
74 &self.name
75 }
76
77 fn version(&self) -> &str {
78 &self.version
79 }
80 fn uid(&self) -> &str {
81 &self.uid
82 }
83}
84
85fn default_drift_type() -> DriftType {
86 DriftType::Psi
87}
88
89impl PsiDriftConfig {
90 pub fn update_feature_map(&mut self, feature_map: FeatureMap) {
91 self.feature_map = feature_map;
92 }
93}
94
95#[pymethods]
96#[allow(clippy::too_many_arguments)]
97impl PsiDriftConfig {
98 #[new]
99 #[pyo3(signature = (space=MISSING, name=MISSING, version=DEFAULT_VERSION, alert_config=PsiAlertConfig::default(), config_path=None, categorical_features=None, binning_strategy=None))]
100 pub fn new(
101 space: &str,
102 name: &str,
103 version: &str,
104 alert_config: PsiAlertConfig,
105 config_path: Option<PathBuf>,
106 categorical_features: Option<Vec<String>>,
107 binning_strategy: Option<&Bound<'_, PyAny>>,
108 ) -> Result<Self, ProfileError> {
109 if let Some(config_path) = config_path {
110 let config = PsiDriftConfig::load_from_json_file(config_path);
111 return config;
112 }
113
114 let binning_strategy = match binning_strategy {
115 None => BinningStrategy::default(),
116 Some(strategy) => {
117 if strategy.is_instance_of::<QuantileBinning>() {
118 BinningStrategy::QuantileBinning(strategy.extract()?)
119 } else if strategy.is_instance_of::<EqualWidthBinning>() {
120 BinningStrategy::EqualWidthBinning(strategy.extract()?)
121 } else {
122 return Err(ProfileError::InvalidBinningStrategyError);
123 }
124 }
125 };
126
127 if name == MISSING || space == MISSING {
128 debug!("Name and space were not provided. Defaulting to __missing__");
129 }
130
131 Ok(Self {
132 name: name.to_string(),
133 space: space.to_string(),
134 version: version.to_string(),
135 uid: create_uuid7(),
136 alert_config,
137 categorical_features,
138 feature_map: FeatureMap::default(),
139 drift_type: DriftType::Psi,
140 binning_strategy,
141 })
142 }
143
144 #[staticmethod]
145 pub fn load_from_json_file(path: PathBuf) -> Result<PsiDriftConfig, ProfileError> {
146 let file = std::fs::read_to_string(&path)?;
149
150 Ok(serde_json::from_str(&file)?)
151 }
152
153 pub fn __str__(&self) -> String {
154 PyHelperFuncs::__str__(self)
156 }
157
158 pub fn model_dump_json(&self) -> String {
159 PyHelperFuncs::__json__(self)
161 }
162
163 #[allow(clippy::too_many_arguments)]
164 #[pyo3(signature = (space=None, name=None, version=None, uid=None, alert_config=None, categorical_features=None, binning_strategy=None))]
165 pub fn update_config_args(
166 &mut self,
167 space: Option<String>,
168 name: Option<String>,
169 version: Option<String>,
170 uid: Option<String>,
171 alert_config: Option<PsiAlertConfig>,
172 categorical_features: Option<Vec<String>>,
173 binning_strategy: Option<&Bound<'_, PyAny>>,
174 ) -> Result<(), TypeError> {
175 if name.is_some() {
176 self.name = name.ok_or(TypeError::MissingNameError)?;
177 }
178
179 if space.is_some() {
180 self.space = space.ok_or(TypeError::MissingSpaceError)?;
181 }
182
183 if version.is_some() {
184 self.version = version.ok_or(TypeError::MissingVersionError)?;
185 }
186
187 if alert_config.is_some() {
188 self.alert_config = alert_config.ok_or(TypeError::MissingAlertConfigError)?;
189 }
190
191 if uid.is_some() {
192 self.uid = uid.ok_or(TypeError::MissingUidError)?;
193 }
194
195 if categorical_features.is_some() {
196 self.categorical_features = categorical_features;
197 }
198
199 if let Some(binning_strategy) = binning_strategy {
200 if binning_strategy.is_instance_of::<QuantileBinning>() {
201 self.binning_strategy =
202 BinningStrategy::QuantileBinning(binning_strategy.extract()?);
203 } else if binning_strategy.is_instance_of::<EqualWidthBinning>() {
204 self.binning_strategy =
205 BinningStrategy::EqualWidthBinning(binning_strategy.extract()?);
206 } else {
207 return Err(TypeError::InvalidBinningStrategyError);
208 }
209 }
210
211 Ok(())
212 }
213
214 #[getter]
215 pub fn binning_strategy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
216 self.binning_strategy.strategy(py)
217 }
218
219 #[setter]
220 pub fn set_binning_strategy(&mut self, strategy: &Bound<'_, PyAny>) -> PyResult<()> {
221 if strategy.is_instance_of::<QuantileBinning>() {
222 self.binning_strategy = BinningStrategy::QuantileBinning(strategy.extract()?);
223 } else if strategy.is_instance_of::<EqualWidthBinning>() {
224 self.binning_strategy = BinningStrategy::EqualWidthBinning(strategy.extract()?);
225 } else {
226 return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
227 "Invalid binning strategy type",
228 ));
229 }
230 Ok(())
231 }
232}
233
234impl Default for PsiDriftConfig {
235 fn default() -> Self {
236 PsiDriftConfig {
237 name: "__missing__".to_string(),
238 space: "__missing__".to_string(),
239 version: DEFAULT_VERSION.to_string(),
240 uid: MISSING.to_string(),
241 feature_map: FeatureMap::default(),
242 alert_config: PsiAlertConfig::default(),
243 drift_type: DriftType::Psi,
244 categorical_features: None,
245 binning_strategy: BinningStrategy::default(),
246 }
247 }
248}
249impl DispatchDriftConfig for PsiDriftConfig {
252 fn get_drift_args(&self) -> DriftArgs {
253 DriftArgs {
254 name: self.name.clone(),
255 space: self.space.clone(),
256 version: self.version.clone(),
257 dispatch_config: self.alert_config.dispatch_config.clone(),
258 }
259 }
260}
261
/// A single PSI bin. Serialized and deserialized by hand so that infinite limits survive JSON.
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
pub struct Bin {
    /// Bin identifier.
    #[pyo3(get)]
    pub id: i32,

    /// Lower edge of the bin, if any (may be `-inf`; `None` presumably for categorical bins).
    #[pyo3(get)]
    pub lower_limit: Option<f64>,

    /// Upper edge of the bin, if any (may be `inf`; `None` presumably for categorical bins).
    #[pyo3(get)]
    pub upper_limit: Option<f64>,

    /// Share of observations falling in this bin (presumably in `[0, 1]`; confirm with profiler).
    #[pyo3(get)]
    pub proportion: f64,
}
277
278impl Serialize for Bin {
279 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
280 where
281 S: Serializer,
282 {
283 let mut state = serializer.serialize_struct("Bin", 4)?;
284 state.serialize_field("id", &self.id)?;
285
286 state.serialize_field(
287 "lower_limit",
288 &self.lower_limit.map(|v| {
289 if v.is_infinite() {
290 serde_json::Value::String(if v.is_sign_positive() {
291 "inf".to_string()
292 } else {
293 "-inf".to_string()
294 })
295 } else {
296 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
297 }
298 }),
299 )?;
300 state.serialize_field(
301 "upper_limit",
302 &self.upper_limit.map(|v| {
303 if v.is_infinite() {
304 serde_json::Value::String(if v.is_sign_positive() {
305 "inf".to_string()
306 } else {
307 "-inf".to_string()
308 })
309 } else {
310 serde_json::Value::Number(serde_json::Number::from_f64(v).unwrap())
311 }
312 }),
313 )?;
314 state.serialize_field("proportion", &self.proportion)?;
315 state.end()
316 }
317}
318
319impl<'de> Deserialize<'de> for Bin {
320 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
321 where
322 D: Deserializer<'de>,
323 {
324 #[derive(Deserialize)]
325 #[serde(untagged)]
326 enum NumberOrString {
327 Number(f64),
328 String(String),
329 }
330
331 #[derive(Deserialize)]
332 #[serde(field_identifier, rename_all = "snake_case")]
333 enum Field {
334 Id,
335 LowerLimit,
336 UpperLimit,
337 Proportion,
338 }
339
340 struct BinVisitor;
341
342 impl<'de> Visitor<'de> for BinVisitor {
343 type Value = Bin;
344
345 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
346 formatter.write_str("struct Bin")
347 }
348
349 fn visit_map<V>(self, mut map: V) -> Result<Bin, V::Error>
350 where
351 V: MapAccess<'de>,
352 {
353 let mut id = None;
354 let mut lower_limit = None;
355 let mut upper_limit = None;
356 let mut proportion = None;
357
358 while let Some(key) = map.next_key()? {
359 match key {
360 Field::Id => {
361 id = Some(map.next_value()?);
362 }
363 Field::LowerLimit => {
364 let val: Option<NumberOrString> = map.next_value()?;
365 lower_limit = Some(val.map(|v| match v {
366 NumberOrString::String(s) => match s.as_str() {
367 "inf" => f64::INFINITY,
368 "-inf" => f64::NEG_INFINITY,
369 _ => s.parse().unwrap(),
370 },
371 NumberOrString::Number(n) => n,
372 }));
373 }
374 Field::UpperLimit => {
375 let val: Option<NumberOrString> = map.next_value()?;
376 upper_limit = Some(val.map(|v| match v {
377 NumberOrString::String(s) => match s.as_str() {
378 "inf" => f64::INFINITY,
379 "-inf" => f64::NEG_INFINITY,
380 _ => s.parse().unwrap(),
381 },
382 NumberOrString::Number(n) => n,
383 }));
384 }
385 Field::Proportion => {
386 proportion = Some(map.next_value()?);
387 }
388 }
389 }
390
391 Ok(Bin {
392 id: id.ok_or_else(|| de::Error::missing_field("id"))?,
393 lower_limit: lower_limit
394 .ok_or_else(|| de::Error::missing_field("lower_limit"))?,
395 upper_limit: upper_limit
396 .ok_or_else(|| de::Error::missing_field("upper_limit"))?,
397 proportion: proportion.ok_or_else(|| de::Error::missing_field("proportion"))?,
398 })
399 }
400 }
401
402 const FIELDS: &[&str] = &["id", "lower_limit", "upper_limit", "proportion"];
403 deserializer.deserialize_struct("Bin", FIELDS, BinVisitor)
404 }
405}
406
/// PSI baseline for a single feature: its bins plus when and how they were computed.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiFeatureDriftProfile {
    /// Feature identifier (presumably the feature name; confirm with the profiler).
    #[pyo3(get)]
    pub id: String,

    /// Bins computed for this feature.
    #[pyo3(get)]
    pub bins: Vec<Bin>,

    /// Timestamp (UTC) associated with this profile.
    #[pyo3(get)]
    pub timestamp: chrono::DateTime<Utc>,

    /// Whether the bins are numeric or categorical.
    #[pyo3(get)]
    pub bin_type: BinType,
}
422
/// Complete PSI drift profile: per-feature baselines, the config used, and the producing version.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct PsiDriftProfile {
    /// Per-feature profiles (keys presumably feature names).
    #[pyo3(get)]
    pub features: HashMap<String, PsiFeatureDriftProfile>,

    /// Configuration this profile was built with.
    #[pyo3(get)]
    pub config: PsiDriftConfig,

    /// Version string recorded when the profile was created.
    #[pyo3(get)]
    pub scouter_version: String,
}
435
436impl PsiDriftProfile {
437 pub fn new(features: HashMap<String, PsiFeatureDriftProfile>, config: PsiDriftConfig) -> Self {
438 Self {
439 features,
440 config,
441 scouter_version: scouter_version(),
442 }
443 }
444}
445
446#[pymethods]
447impl PsiDriftProfile {
448 pub fn __str__(&self) -> String {
449 PyHelperFuncs::__str__(self)
451 }
452
453 #[getter]
454 pub fn uid(&self) -> String {
455 self.config.uid.clone()
456 }
457
458 #[setter]
459 pub fn set_uid(&mut self, uid: String) {
460 self.config.uid = uid;
461 }
462
463 #[getter]
464 pub fn drift_type(&self) -> DriftType {
465 self.config.drift_type.clone()
466 }
467 pub fn model_dump_json(&self) -> String {
468 PyHelperFuncs::__json__(self)
470 }
471 #[allow(clippy::useless_conversion)]
473 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, ProfileError> {
474 let json_str = serde_json::to_string(&self)?;
475
476 let json_value: Value = serde_json::from_str(&json_str)?;
477
478 let dict = PyDict::new(py);
480
481 json_to_pyobject(py, &json_value, &dict)?;
483
484 Ok(dict.into())
486 }
487
488 #[staticmethod]
489 pub fn from_file(path: PathBuf) -> Result<PsiDriftProfile, ProfileError> {
490 let file = std::fs::read_to_string(&path)?;
491
492 Ok(serde_json::from_str(&file)?)
493 }
494
495 #[staticmethod]
496 pub fn model_validate(data: &Bound<'_, PyDict>) -> PsiDriftProfile {
497 let json_value = pyobject_to_json(data).unwrap();
498
499 let string = serde_json::to_string(&json_value).unwrap();
500 serde_json::from_str(&string).expect("Failed to load drift profile")
501 }
502
503 #[staticmethod]
504 pub fn model_validate_json(json_string: String) -> PsiDriftProfile {
505 serde_json::from_str(&json_string).expect("Failed to load monitor profile")
507 }
508
509 #[pyo3(signature = (path=None))]
511 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
512 Ok(PyHelperFuncs::save_to_json(
513 self,
514 path,
515 FileName::PsiDriftProfile.to_str(),
516 )?)
517 }
518
519 #[allow(clippy::too_many_arguments)]
520 #[pyo3(signature = (space=None, name=None, version=None, uid=None,alert_config=None, categorical_features=None, binning_strategy=None))]
521 pub fn update_config_args(
522 &mut self,
523 space: Option<String>,
524 name: Option<String>,
525 version: Option<String>,
526 uid: Option<String>,
527 alert_config: Option<PsiAlertConfig>,
528 categorical_features: Option<Vec<String>>,
529 binning_strategy: Option<&Bound<'_, PyAny>>,
530 ) -> Result<(), TypeError> {
531 self.config.update_config_args(
532 space,
533 name,
534 version,
535 uid,
536 alert_config,
537 categorical_features,
538 binning_strategy,
539 )
540 }
541
542 pub fn create_profile_request(&self) -> Result<ProfileRequest, TypeError> {
544 let version: Option<String> = if self.config.version == DEFAULT_VERSION {
545 None
546 } else {
547 Some(self.config.version.clone())
548 };
549
550 Ok(ProfileRequest {
551 space: self.config.space.clone(),
552 profile: self.model_dump_json(),
553 drift_type: self.config.drift_type.clone(),
554 version_request: Some(VersionRequest {
555 version,
556 version_type: VersionType::Minor,
557 pre_tag: None,
558 build_tag: None,
559 }),
560 active: false,
561 deactivate_others: false,
562 })
563 }
564}
565
/// Per-feature PSI results for one monitored entity.
#[pyclass]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PsiDriftMap {
    /// Feature name -> value (presumably the computed PSI score; confirm with caller).
    #[pyo3(get)]
    pub features: HashMap<String, f64>,

    /// Name of the monitored entity.
    #[pyo3(get)]
    pub name: String,

    /// Space of the monitored entity.
    #[pyo3(get)]
    pub space: String,

    /// Version of the monitored entity.
    #[pyo3(get)]
    pub version: String,
}
581
582impl PsiDriftMap {
583 pub fn new(space: String, name: String, version: String) -> Self {
584 Self {
585 features: HashMap::new(),
586 name,
587 space,
588 version,
589 }
590 }
591}
592
593#[pymethods]
594#[allow(clippy::new_without_default)]
595impl PsiDriftMap {
596 pub fn __str__(&self) -> String {
597 PyHelperFuncs::__str__(self)
599 }
600
601 pub fn model_dump_json(&self) -> String {
602 PyHelperFuncs::__json__(self)
604 }
605
606 #[staticmethod]
607 pub fn model_validate_json(json_string: String) -> Result<PsiDriftMap, ProfileError> {
608 Ok(serde_json::from_str(&json_string)?)
610 }
611
612 #[pyo3(signature = (path=None))]
613 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
614 Ok(PyHelperFuncs::save_to_json(
615 self,
616 path,
617 FileName::PsiDriftMap.to_str(),
618 )?)
619 }
620}
621
622impl ProfileBaseArgs for PsiDriftProfile {
624 type Config = PsiDriftConfig;
625
626 fn config(&self) -> &Self::Config {
627 &self.config
628 }
629 fn get_base_args(&self) -> ProfileArgs {
631 ProfileArgs {
632 name: self.config.name.clone(),
633 space: self.config.space.clone(),
634 version: Some(self.config.version.clone()),
635 schedule: self.config.alert_config.schedule.clone(),
636 scouter_version: self.scouter_version.clone(),
637 drift_type: self.config.drift_type.clone(),
638 }
639 }
640
641 fn to_value(&self) -> Value {
643 serde_json::to_value(self).unwrap()
644 }
645}
646
/// Observed bin distribution for one feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionData {
    /// Number of observations the distribution was computed from.
    pub sample_size: u64,
    /// Bin id -> value (presumably the proportion of observations in that bin), ordered by id.
    pub bins: BTreeMap<i32, f64>,
}
652
/// Distributions for many features, keyed by feature name (a `BTreeMap`, so iteration is sorted).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureDistributions {
    pub distributions: BTreeMap<String, DistributionData>,
}
657
658impl FeatureDistributions {
659 pub fn is_empty(&self) -> bool {
660 self.distributions.is_empty()
661 }
662}
663
/// One database row: a feature name and its observed distribution.
#[derive(Debug)]
pub struct FeatureDistributionRow {
    pub feature: String,
    pub distribution: DistributionData,
}
669
#[cfg(feature = "server")]
impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for FeatureDistributionRow {
    /// Decodes a row with columns `feature` (text), `sample_size` (decoded as `i64`) and
    /// `bins` (JSON object mapping bin ids to values).
    ///
    /// # Errors
    /// Returns `sqlx::Error` if a column is missing or has the wrong type, if `bins` is not a
    /// valid id->value JSON object, or if `sample_size` is negative. A negative count used to be
    /// cast silently into a huge `u64`.
    fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
        use sqlx::Row;

        let feature: String = row.try_get("feature")?;
        let sample_size: i64 = row.try_get("sample_size")?;
        if sample_size < 0 {
            return Err(sqlx::Error::Decode(
                format!("negative sample_size {sample_size} for feature '{feature}'").into(),
            ));
        }
        let bins_json: serde_json::Value = row.try_get("bins")?;
        let bins: BTreeMap<i32, f64> =
            serde_json::from_value(bins_json).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        Ok(FeatureDistributionRow {
            feature,
            distribution: DistributionData {
                // Lossless: non-negativity was checked above.
                sample_size: sample_size as u64,
                bins,
            },
        })
    }
}
690
691impl FeatureDistributions {
692 pub fn from_rows(rows: Vec<FeatureDistributionRow>) -> Self {
694 let distributions = rows
695 .into_iter()
696 .map(|row| (row.feature, row.distribution))
697 .collect();
698
699 FeatureDistributions { distributions }
700 }
701}