Skip to main content

scouter_drift/spc/
types.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::DriftError;
3use core::fmt::Debug;
4use ndarray::Array;
5use ndarray::Array2;
6use numpy::{IntoPyArray, PyArray2};
7use pyo3::prelude::*;
8use scouter_types::error::UtilError;
9
10use scouter_types::{FileName, PyHelperFuncs};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::PathBuf;
14
15/// Python class for a feature drift
16///
17/// # Arguments
18///
19/// * `samples` - A vector of samples
20/// * `drift` - A vector of drift values
21///
22#[pyclass]
23#[derive(Debug, Serialize, Deserialize, Clone)]
24pub struct SpcFeatureDrift {
25    #[pyo3(get)]
26    pub samples: Vec<f64>,
27
28    #[pyo3(get)]
29    pub drift: Vec<f64>,
30}
31
32impl SpcFeatureDrift {
33    pub fn __str__(&self) -> String {
34        // serialize the struct to a string
35        serde_json::to_string_pretty(&self).unwrap()
36    }
37}
38
39/// Python class for a Drift map of features with calculated drift
40///
41/// # Arguments
42///
43/// * `features` - A hashmap of feature names and their drift
44///
45#[pyclass]
46#[derive(Debug, Serialize, Deserialize, Clone)]
47pub struct SpcDriftMap {
48    #[pyo3(get)]
49    pub features: HashMap<String, SpcFeatureDrift>,
50
51    #[pyo3(get)]
52    pub name: String,
53
54    #[pyo3(get)]
55    pub space: String,
56
57    #[pyo3(get)]
58    pub version: String,
59}
60
61#[pymethods]
62#[allow(clippy::new_without_default)]
63impl SpcDriftMap {
64    pub fn __str__(&self) -> String {
65        // serialize the struct to a string
66        PyHelperFuncs::__str__(self)
67    }
68
69    pub fn model_dump_json(&self) -> String {
70        // serialize the struct to a string
71        PyHelperFuncs::__json__(self)
72    }
73
74    #[staticmethod]
75    pub fn model_validate_json(json_string: String) -> Result<SpcDriftMap, UtilError> {
76        // deserialize the string to a struct
77        Ok(serde_json::from_str(&json_string)?)
78    }
79
80    #[pyo3(signature = (path=None))]
81    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, UtilError> {
82        PyHelperFuncs::save_to_json(self, path, FileName::SpcDriftMap.to_str())
83    }
84
85    #[allow(clippy::type_complexity)]
86    pub fn to_numpy<'py>(
87        &self,
88        py: Python<'py>,
89    ) -> Result<
90        (
91            Bound<'py, PyArray2<f64>>,
92            Bound<'py, PyArray2<f64>>,
93            Vec<String>,
94        ),
95        DriftError,
96    > {
97        let (drift_array, sample_array, features) = self.to_array()?;
98
99        Ok((
100            drift_array.into_pyarray(py).to_owned(),
101            sample_array.into_pyarray(py).to_owned(),
102            features,
103        ))
104    }
105}
106
107type ArrayReturn = (Array2<f64>, Array2<f64>, Vec<String>);
108
109impl SpcDriftMap {
110    pub fn new(space: String, name: String, version: String) -> Self {
111        Self {
112            features: HashMap::new(),
113            name,
114            space,
115            version,
116        }
117    }
118
119    pub fn to_array(&self) -> Result<ArrayReturn, DriftError> {
120        let columns = self.features.len();
121        let rows = self.features.values().next().unwrap().samples.len();
122
123        // create empty array
124        let mut drift_array = Array2::<f64>::zeros((rows, columns));
125        let mut sample_array = Array2::<f64>::zeros((rows, columns));
126        let mut features = Vec::new();
127
128        // iterate over the features and insert the drift values
129        for (i, (feature, drift)) in self.features.iter().enumerate() {
130            features.push(feature.clone());
131            drift_array
132                .column_mut(i)
133                .assign(&Array::from(drift.drift.clone()));
134            sample_array
135                .column_mut(i)
136                .assign(&Array::from(drift.samples.clone()));
137        }
138
139        Ok((drift_array, sample_array, features))
140    }
141
142    pub fn add_feature(&mut self, feature: String, drift: SpcFeatureDrift) {
143        self.features.insert(feature, drift);
144    }
145}
146// Drift config to use when calculating drift on a new sample of data