1#![allow(clippy::useless_conversion)]
2use crate::data_utils::DataConverterEnum;
3use crate::drifter::{custom::CustomDrifter, psi::PsiDrifter, spc::SpcDrifter};
4use pyo3::prelude::*;
5use pyo3::types::PyList;
6use pyo3::IntoPyObjectExt;
7use scouter_drift::error::DriftError;
8use scouter_drift::spc::SpcDriftMap;
9use scouter_types::spc::SpcDriftProfile;
10use scouter_types::{
11 custom::{CustomDriftProfile, CustomMetric, CustomMetricDriftConfig},
12 psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile},
13 spc::SpcDriftConfig,
14 DataType, DriftProfile, DriftType,
15};
16
/// Result of a drift computation, tagged by the drift method that produced it.
///
/// There is no custom-metric variant: computing drift with a custom drifter
/// currently yields `DriftError::NotImplemented` instead of a map.
pub enum DriftMap {
    /// Drift map produced by the SPC drifter.
    Spc(SpcDriftMap),
    /// Drift map produced by the PSI drifter.
    Psi(PsiDriftMap),
}
21
/// Drift configuration for exactly one drift method, extracted from a Python config object.
///
/// Borrow the payload with `spc_config`, `psi_config` or `custom_config`. Each accessor
/// returns `DriftError::InvalidConfigError` when the variant does not match.
pub enum DriftConfig {
    /// SPC drift configuration.
    Spc(SpcDriftConfig),
    /// PSI drift configuration.
    Psi(PsiDriftConfig),
    /// Custom-metric drift configuration.
    Custom(CustomMetricDriftConfig),
}
27
28impl DriftConfig {
29 pub fn spc_config(&self) -> Result<&SpcDriftConfig, DriftError> {
30 match self {
31 DriftConfig::Spc(cfg) => Ok(cfg),
32 _ => Err(DriftError::InvalidConfigError),
33 }
34 }
35
36 pub fn psi_config(&self) -> Result<&PsiDriftConfig, DriftError> {
37 match self {
38 DriftConfig::Psi(cfg) => Ok(cfg),
39 _ => Err(DriftError::InvalidConfigError),
40 }
41 }
42
43 pub fn custom_config(&self) -> Result<&CustomMetricDriftConfig, DriftError> {
44 match self {
45 DriftConfig::Custom(cfg) => Ok(cfg),
46 _ => Err(DriftError::InvalidConfigError),
47 }
48 }
49}
50
/// Drift-method-specific worker; build one with `Drifter::from_drift_type`.
pub enum Drifter {
    /// Worker used for `DriftType::Spc`.
    Spc(SpcDrifter),
    /// Worker used for `DriftType::Psi`.
    Psi(PsiDrifter),
    /// Worker used for `DriftType::Custom`.
    Custom(CustomDrifter),
}
56
57impl Drifter {
58 fn from_drift_type(drift_type: DriftType) -> Self {
59 match drift_type {
60 DriftType::Spc => Drifter::Spc(SpcDrifter::new()),
61 DriftType::Psi => Drifter::Psi(PsiDrifter::new()),
62 DriftType::Custom => Drifter::Custom(CustomDrifter::new()),
63 }
64 }
65
66 fn create_drift_profile<'py>(
67 &mut self,
68 py: Python<'py>,
69 data: &Bound<'py, PyAny>,
70 data_type: &DataType,
71 config: DriftConfig,
72 ) -> Result<DriftProfile, DriftError> {
73 match self {
74 Drifter::Spc(drifter) => {
77 let data = DataConverterEnum::convert_data(py, data_type, data)?;
78 let profile = drifter.create_drift_profile(data, config.spc_config()?.clone())?;
79 Ok(DriftProfile::Spc(profile))
80 }
81 Drifter::Psi(drifter) => {
82 let data = DataConverterEnum::convert_data(py, data_type, data)?;
83 let profile = drifter.create_drift_profile(data, config.psi_config()?.clone())?;
84 Ok(DriftProfile::Psi(profile))
85 }
86 Drifter::Custom(drifter) => {
87 let data = if data.is_instance_of::<PyList>() {
90 data.extract::<Vec<CustomMetric>>()?
91 } else {
92 let metric = data.extract::<CustomMetric>()?;
93 vec![metric]
94 };
95
96 let profile =
97 drifter.create_drift_profile(config.custom_config()?.clone(), data, None)?;
98 Ok(DriftProfile::Custom(profile))
99 }
100 }
101 }
102
103 fn compute_drift<'py>(
104 &mut self,
105 py: Python<'py>,
106 data: &Bound<'py, PyAny>,
107 data_type: &DataType,
108 profile: &DriftProfile,
109 ) -> Result<DriftMap, DriftError> {
110 match self {
111 Drifter::Spc(drifter) => {
112 let data = DataConverterEnum::convert_data(py, data_type, data)?;
113 let drift_profile = profile.get_spc_profile()?;
114 let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
115 Ok(DriftMap::Spc(drift_map))
116 }
117 Drifter::Psi(drifter) => {
118 let data = DataConverterEnum::convert_data(py, data_type, data)?;
119 let drift_profile = profile.get_psi_profile()?;
120 let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
121 Ok(DriftMap::Psi(drift_map))
122 }
123 Drifter::Custom(_) => {
124 Err(DriftError::NotImplemented)
126 }
127 }
128 }
129}
130
/// Python-facing drifter, exposed to Python under the class name `Drifter`.
///
/// Holds no state: every call constructs a fresh internal `Drifter` for the requested
/// drift type.
#[pyclass(name = "Drifter")]
#[derive(Debug, Default)]
pub struct PyDrifter {}
134
135#[pymethods]
136impl PyDrifter {
137 #[new]
138 pub fn new() -> Self {
139 Self {}
140 }
141
142 #[pyo3(signature = (data, config=None, data_type=None))]
143 pub fn create_drift_profile<'py>(
144 &self,
145 py: Python<'py>,
146 data: &Bound<'py, PyAny>,
147 config: Option<&Bound<'py, PyAny>>,
148 data_type: Option<&DataType>,
149 ) -> Result<Bound<'py, PyAny>, DriftError> {
150 let (config_helper, drift_type) = if config.is_some() {
153 let obj = config.unwrap();
154 let drift_type = obj.getattr("drift_type")?.extract::<DriftType>()?;
155 let drift_config = match drift_type {
156 DriftType::Spc => {
157 let config = obj.extract::<SpcDriftConfig>()?;
158 DriftConfig::Spc(config)
159 }
160 DriftType::Psi => {
161 let config = obj.extract::<PsiDriftConfig>()?;
162 DriftConfig::Psi(config)
163 }
164 DriftType::Custom => {
165 let config = obj.extract::<CustomMetricDriftConfig>()?;
166 DriftConfig::Custom(config)
167 }
168 };
169 (drift_config, drift_type)
170 } else {
171 (DriftConfig::Spc(SpcDriftConfig::default()), DriftType::Spc)
172 };
173
174 let mut drift_helper = Drifter::from_drift_type(drift_type);
175
176 let data_type = match data_type {
179 Some(data_type) => data_type,
180 None => {
181 let class = data.getattr("__class__")?;
182 let module = class.getattr("__module__")?.str()?.to_string();
183 let name = class.getattr("__name__")?.str()?.to_string();
184 let full_class_name = format!("{module}.{name}");
185
186 &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
187 }
189 };
190
191 let profile = drift_helper.create_drift_profile(py, data, data_type, config_helper)?;
192
193 match profile {
194 DriftProfile::Spc(profile) => Ok(profile.into_bound_py_any(py)?),
195 DriftProfile::Psi(profile) => Ok(profile.into_bound_py_any(py)?),
196 DriftProfile::Custom(profile) => Ok(profile.into_bound_py_any(py)?),
197 }
198 }
199
200 #[pyo3(signature = (data, drift_profile, data_type=None))]
201 pub fn compute_drift<'py>(
202 &self,
203 py: Python<'py>,
204 data: &Bound<'py, PyAny>,
205 drift_profile: &Bound<'py, PyAny>,
206 data_type: Option<&DataType>,
207 ) -> Result<Bound<'py, PyAny>, DriftError> {
208 let drift_type = drift_profile
209 .getattr("config")?
210 .getattr("drift_type")?
211 .extract::<DriftType>()?;
212
213 let profile = match drift_type {
214 DriftType::Spc => {
215 let profile = drift_profile.extract::<SpcDriftProfile>()?;
216 DriftProfile::Spc(profile)
217 }
218 DriftType::Psi => {
219 let profile = drift_profile.extract::<PsiDriftProfile>()?;
220 DriftProfile::Psi(profile)
221 }
222 DriftType::Custom => {
223 let profile = drift_profile.extract::<CustomDriftProfile>()?;
224 DriftProfile::Custom(profile)
225 }
226 };
227
228 let data_type = match data_type {
231 Some(data_type) => data_type,
232 None => {
233 let class = data.getattr("__class__")?;
234 let module = class.getattr("__module__")?.str()?.to_string();
235 let name = class.getattr("__name__")?.str()?.to_string();
236 let full_class_name = format!("{module}.{name}");
237
238 &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
239 }
241 };
242
243 let mut drift_helper = Drifter::from_drift_type(drift_type);
244
245 let drift_map = drift_helper.compute_drift(py, data, data_type, &profile)?;
246
247 match drift_map {
248 DriftMap::Spc(map) => Ok(map.into_bound_py_any(py)?),
249 DriftMap::Psi(map) => Ok(map.into_bound_py_any(py)?),
250 }
251 }
252}
253
254impl PyDrifter {
255 pub fn internal_create_drift_profile<'py>(
257 &self,
258 py: Python,
259 data: &Bound<'py, PyAny>,
260 config: Option<&Bound<'py, PyAny>>,
261 data_type: Option<&DataType>,
262 ) -> Result<DriftProfile, DriftError> {
263 let (config_helper, drift_type) = if config.is_some() {
266 let obj = config.unwrap();
267 let drift_type = obj.getattr("drift_type")?.extract::<DriftType>()?;
268 let drift_config = match drift_type {
269 DriftType::Spc => {
270 let config = obj.extract::<SpcDriftConfig>()?;
271 DriftConfig::Spc(config)
272 }
273 DriftType::Psi => {
274 let config = obj.extract::<PsiDriftConfig>()?;
275 DriftConfig::Psi(config)
276 }
277 DriftType::Custom => {
278 let config = obj.extract::<CustomMetricDriftConfig>()?;
279 DriftConfig::Custom(config)
280 }
281 };
282 (drift_config, drift_type)
283 } else {
284 (DriftConfig::Spc(SpcDriftConfig::default()), DriftType::Spc)
285 };
286
287 let mut drift_helper = Drifter::from_drift_type(drift_type);
288
289 let data_type = match data_type {
292 Some(data_type) => data_type,
293 None => {
294 let class = data.getattr("__class__")?;
295 let module = class.getattr("__module__")?.str()?.to_string();
296 let name = class.getattr("__name__")?.str()?.to_string();
297 let full_class_name = format!("{module}.{name}");
298
299 &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
300 }
302 };
303
304 drift_helper.create_drift_profile(py, data, data_type, config_helper)
305 }
306}