scouter_client/drifter/
scouter.rs

1use crate::data_utils::DataConverterEnum;
2use crate::drifter::{
3    custom::CustomDrifter, llm::ClientLLMDrifter, psi::PsiDrifter, spc::SpcDrifter,
4};
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7use pyo3::IntoPyObjectExt;
8use scouter_drift::error::DriftError;
9use scouter_drift::spc::SpcDriftMap;
10use scouter_types::llm::{LLMDriftMap, LLMDriftMetric};
11use scouter_types::spc::SpcDriftProfile;
12use scouter_types::LLMRecord;
13use scouter_types::{
14    custom::{CustomDriftProfile, CustomMetric, CustomMetricDriftConfig},
15    llm::{LLMDriftConfig, LLMDriftProfile},
16    psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile},
17    spc::SpcDriftConfig,
18    DataType, DriftProfile, DriftType,
19};
20use std::fmt::Debug;
21use std::sync::Arc;
22use std::sync::RwLock;
23
/// Result of a drift computation, tagged by the kind of drifter that produced it.
///
/// There is no `Custom` variant: `Drifter::compute_drift` returns
/// `DriftError::NotImplemented` for custom drifters.
pub enum DriftMap {
    /// Drift map produced by the SPC drifter.
    Spc(SpcDriftMap),
    /// Drift map produced by the PSI drifter.
    Psi(PsiDriftMap),
    /// Drift map wrapping the records produced by the LLM drifter.
    LLM(LLMDriftMap),
}
29
/// Typed drift configuration handed to a `Drifter` when creating a drift profile.
///
/// The SPC and PSI configurations are shared behind `Arc<RwLock<..>>`, while the
/// LLM and custom configurations are held by value.
pub enum DriftConfig {
    /// Shared SPC configuration.
    Spc(Arc<RwLock<SpcDriftConfig>>),
    /// Shared PSI configuration.
    Psi(Arc<RwLock<PsiDriftConfig>>),
    /// LLM configuration.
    LLM(LLMDriftConfig),
    /// Custom-metric configuration.
    Custom(CustomMetricDriftConfig),
}
36
37impl DriftConfig {
38    pub fn spc_config(&self) -> Result<Arc<RwLock<SpcDriftConfig>>, DriftError> {
39        match self {
40            DriftConfig::Spc(cfg) => Ok(cfg.clone()),
41            _ => Err(DriftError::InvalidConfigError),
42        }
43    }
44
45    pub fn psi_config(&self) -> Result<Arc<RwLock<PsiDriftConfig>>, DriftError> {
46        match self {
47            DriftConfig::Psi(cfg) => Ok(cfg.clone()),
48            _ => Err(DriftError::InvalidConfigError),
49        }
50    }
51
52    pub fn custom_config(&self) -> Result<CustomMetricDriftConfig, DriftError> {
53        match self {
54            DriftConfig::Custom(cfg) => Ok(cfg.clone()),
55            _ => Err(DriftError::InvalidConfigError),
56        }
57    }
58
59    pub fn llm_config(&self) -> Result<LLMDriftConfig, DriftError> {
60        match self {
61            DriftConfig::LLM(cfg) => Ok(cfg.clone()),
62            _ => Err(DriftError::InvalidConfigError),
63        }
64    }
65}
66
/// Dispatch wrapper over the concrete drifter implementations, one variant per
/// supported drift type.
pub enum Drifter {
    /// SPC drifter.
    Spc(SpcDrifter),
    /// PSI drifter.
    Psi(PsiDrifter),
    /// Custom-metric drifter.
    Custom(CustomDrifter),
    /// Client-side LLM drifter.
    LLM(ClientLLMDrifter),
}
73
74impl Debug for Drifter {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        match self {
77            Drifter::Spc(_) => write!(f, "SpcDrifter"),
78            Drifter::Psi(_) => write!(f, "PsiDrifter"),
79            Drifter::Custom(_) => write!(f, "CustomDrifter"),
80            Drifter::LLM(_) => write!(f, "LLMDrifter"),
81        }
82    }
83}
84
85impl Drifter {
86    fn from_drift_type(drift_type: DriftType) -> Result<Self, DriftError> {
87        match drift_type {
88            DriftType::Spc => Ok(Drifter::Spc(SpcDrifter::new())),
89            DriftType::Psi => Ok(Drifter::Psi(PsiDrifter::new())),
90            DriftType::Custom => Ok(Drifter::Custom(CustomDrifter::new())),
91            DriftType::LLM => Ok(Drifter::LLM(ClientLLMDrifter::new())),
92        }
93    }
94
95    fn create_drift_profile<'py>(
96        &mut self,
97        py: Python<'py>,
98        data: &Bound<'py, PyAny>,
99        data_type: &DataType,
100        config: DriftConfig,
101        workflow: Option<Bound<'py, PyAny>>,
102    ) -> Result<DriftProfile, DriftError> {
103        match self {
104            // Before creating the profile, we first need to do a rough split of the data into string and numeric data types before
105            // passing it to the drifter
106            Drifter::Spc(drifter) => {
107                let data = DataConverterEnum::convert_data(py, data_type, data)?;
108                let profile = drifter.create_drift_profile(data, config.spc_config()?)?;
109                Ok(DriftProfile::Spc(profile))
110            }
111            Drifter::Psi(drifter) => {
112                let data = DataConverterEnum::convert_data(py, data_type, data)?;
113                let profile = drifter.create_drift_profile(data, config.psi_config()?)?;
114                Ok(DriftProfile::Psi(profile))
115            }
116            Drifter::Custom(drifter) => {
117                // check if data is pylist. If it is, convert to Vec<CustomMetric>
118                // if not extract to CustomMetric and add to vec
119                let data = if data.is_instance_of::<PyList>() {
120                    data.extract::<Vec<CustomMetric>>()?
121                } else {
122                    let metric = data.extract::<CustomMetric>()?;
123                    vec![metric]
124                };
125
126                let profile = drifter.create_drift_profile(config.custom_config()?, data)?;
127                Ok(DriftProfile::Custom(profile))
128            }
129            Drifter::LLM(drifter) => {
130                // LLM drift profiles are created separately, so we will handle this in the create_llm_drift_profile method
131                let metrics = if data.is_instance_of::<PyList>() {
132                    data.extract::<Vec<LLMDriftMetric>>()?
133                } else {
134                    let metric = data.extract::<LLMDriftMetric>()?;
135                    vec![metric]
136                };
137                let profile =
138                    drifter.create_drift_profile(config.llm_config()?, metrics, workflow)?;
139                Ok(DriftProfile::LLM(profile))
140            }
141        }
142    }
143
144    fn compute_drift<'py>(
145        &mut self,
146        py: Python<'py>,
147        data: &Bound<'py, PyAny>,
148        data_type: &DataType,
149        profile: &DriftProfile,
150    ) -> Result<DriftMap, DriftError> {
151        match self {
152            Drifter::Spc(drifter) => {
153                let data = DataConverterEnum::convert_data(py, data_type, data)?;
154                let drift_profile = profile.get_spc_profile()?;
155                let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
156                Ok(DriftMap::Spc(drift_map))
157            }
158            Drifter::Psi(drifter) => {
159                let data = DataConverterEnum::convert_data(py, data_type, data)?;
160                let drift_profile = profile.get_psi_profile()?;
161                let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
162                Ok(DriftMap::Psi(drift_map))
163            }
164            Drifter::Custom(_) => {
165                // check if data is pylist. If it is, convert to Vec<CustomMetric>
166                Err(DriftError::NotImplemented)
167            }
168
169            Drifter::LLM(drifter) => {
170                // extract data to be Vec<LLMRecord>
171                let data = if data.is_instance_of::<PyList>() {
172                    data.extract::<Vec<LLMRecord>>()?
173                } else {
174                    let metric = data.extract::<LLMRecord>()?;
175                    vec![metric]
176                };
177                let records = drifter.compute_drift(data, profile.get_llm_profile()?)?;
178
179                Ok(DriftMap::LLM(LLMDriftMap { records }))
180            }
181        }
182    }
183}
184
// Stateless handle exposed to Python under the class name `Drifter`. It holds no
// fields; a concrete `Drifter` is built on each call from the drift type of the
// supplied config or profile.
// NOTE: plain `//` comments are used on purpose, since `///` on a `#[pyclass]`
// would become the Python class docstring.
#[pyclass(name = "Drifter")]
#[derive(Debug, Default)]
pub struct PyDrifter {}
188
189#[pymethods]
190impl PyDrifter {
191    #[new]
192    pub fn new() -> Self {
193        Self {}
194    }
195
196    /// This method is used to create a drift profile based on the data and config provided
197    /// It will automatically infer the data type if not provided
198    /// If the config is not provided, it will create a default config based on the drift type
199    /// The data can be a numpy array, pandas dataframe, or pyarrow table, Vec<CustomMetric>, or Vec<LLMDriftMetric>
200    /// ## Arguments:
201    /// - `data`: The data to create the drift profile from. This can be a numpy array, pandas dataframe, pyarrow table, Vec<CustomMetric>, or Vec<LLMDriftMetric>.
202    /// - `config`: The configuration for the drift profile. This is optional and if not provided, a default configuration will be created based on the drift type.
203    /// - `data_type`: The type of the data. This is optional and if not provided, it will be inferred from the data class name.
204    /// - `workflow`: An optional workflow to be used with the drift profile. This is only applicable for LLM drift profiles.
205    #[pyo3(signature = (data, config=None, data_type=None, workflow=None))]
206    pub fn create_drift_profile<'py>(
207        &self,
208        py: Python<'py>,
209        data: &Bound<'py, PyAny>,
210        config: Option<&Bound<'py, PyAny>>,
211        data_type: Option<&DataType>,
212        workflow: Option<Bound<'py, PyAny>>,
213    ) -> Result<Bound<'py, PyAny>, DriftError> {
214        // if config is None, then we need to create a default config
215
216        let (config_helper, drift_type) = if let Some(obj) = config {
217            let drift_type = obj.getattr("drift_type")?.extract::<DriftType>()?;
218            let drift_config = match drift_type {
219                DriftType::Spc => {
220                    let config = obj.extract::<SpcDriftConfig>()?;
221                    DriftConfig::Spc(Arc::new(config.into()))
222                }
223                DriftType::Psi => {
224                    let config = obj.extract::<PsiDriftConfig>()?;
225                    DriftConfig::Psi(Arc::new(config.into()))
226                }
227                DriftType::Custom => {
228                    let config = obj.extract::<CustomMetricDriftConfig>()?;
229                    DriftConfig::Custom(config)
230                }
231                DriftType::LLM => {
232                    let config = obj.extract::<LLMDriftConfig>()?;
233                    DriftConfig::LLM(config)
234                }
235            };
236            (drift_config, drift_type)
237        } else {
238            (
239                DriftConfig::Spc(Arc::new(SpcDriftConfig::default().into())),
240                DriftType::Spc,
241            )
242        };
243
244        let mut drift_helper = Drifter::from_drift_type(drift_type)?;
245
246        // if data_type is None, try to infer it from the class name
247        // This is for handling, numpy, pandas, pyarrow
248        let data_type = match data_type {
249            Some(data_type) => data_type,
250            None => {
251                let class = data.getattr("__class__")?;
252                let module = class.getattr("__module__")?.str()?.to_string();
253                let name = class.getattr("__name__")?.str()?.to_string();
254                let full_class_name = format!("{module}.{name}");
255
256                &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
257                // for handling custom
258            }
259        };
260
261        let profile =
262            drift_helper.create_drift_profile(py, data, data_type, config_helper, workflow)?;
263
264        match profile {
265            DriftProfile::Spc(profile) => Ok(profile.into_bound_py_any(py)?),
266            DriftProfile::Psi(profile) => Ok(profile.into_bound_py_any(py)?),
267            DriftProfile::Custom(profile) => Ok(profile.into_bound_py_any(py)?),
268            DriftProfile::LLM(profile) => Ok(profile.into_bound_py_any(py)?),
269        }
270    }
271
272    // Specific method for creating LLM drift profiles
273    // This is to avoid confusion with the other drifters
274    #[pyo3(signature = (config, metrics, workflow=None))]
275    pub fn create_llm_drift_profile<'py>(
276        &mut self,
277        py: Python<'py>,
278        config: LLMDriftConfig,
279        metrics: Vec<LLMDriftMetric>,
280        workflow: Option<Bound<'py, PyAny>>,
281    ) -> Result<Bound<'py, PyAny>, DriftError> {
282        let profile = LLMDriftProfile::new(config, metrics, workflow)?;
283        Ok(profile.into_bound_py_any(py)?)
284    }
285
286    #[pyo3(signature = (data, drift_profile, data_type=None))]
287    pub fn compute_drift<'py>(
288        &self,
289        py: Python<'py>,
290        data: &Bound<'py, PyAny>,
291        drift_profile: &Bound<'py, PyAny>,
292        data_type: Option<&DataType>,
293    ) -> Result<Bound<'py, PyAny>, DriftError> {
294        let drift_type = drift_profile
295            .getattr("config")?
296            .getattr("drift_type")?
297            .extract::<DriftType>()?;
298
299        let profile = match drift_type {
300            DriftType::Spc => {
301                let profile = drift_profile.extract::<SpcDriftProfile>()?;
302                DriftProfile::Spc(profile)
303            }
304            DriftType::Psi => {
305                let profile = drift_profile.extract::<PsiDriftProfile>()?;
306                DriftProfile::Psi(profile)
307            }
308            DriftType::Custom => {
309                let profile = drift_profile.extract::<CustomDriftProfile>()?;
310                DriftProfile::Custom(profile)
311            }
312            DriftType::LLM => {
313                let profile = drift_profile.extract::<LLMDriftProfile>()?;
314                DriftProfile::LLM(profile)
315            }
316        };
317
318        // if data_type is None, try to infer it from the class name
319        // This is for handling, numpy, pandas, pyarrow
320        // skip if drift_type is LLM, as it will be handled separately
321
322        let data_type = match data_type {
323            Some(data_type) => data_type,
324            None => {
325                if drift_type == DriftType::LLM {
326                    // For LLM, we will handle it separately in the create_llm_drift_profile method
327                    &DataType::LLM
328                } else {
329                    let class = data.getattr("__class__")?;
330                    let module = class.getattr("__module__")?.str()?.to_string();
331                    let name = class.getattr("__name__")?.str()?.to_string();
332                    let full_class_name = format!("{module}.{name}");
333
334                    // for handling custom
335                    &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
336                }
337            }
338        };
339
340        let mut drift_helper = Drifter::from_drift_type(drift_type)?;
341
342        let drift_map = drift_helper.compute_drift(py, data, data_type, &profile)?;
343
344        match drift_map {
345            DriftMap::Spc(map) => Ok(map.into_bound_py_any(py)?),
346            DriftMap::Psi(map) => Ok(map.into_bound_py_any(py)?),
347            DriftMap::LLM(map) => Ok(map.into_bound_py_any(py)?),
348        }
349    }
350}
351
352impl PyDrifter {
353    /// Reproduction of `create_llm_drift_profile` but allows for passing a runtime
354    /// This is used in opsml to allow passing the Opsml runtime
355    pub fn create_llm_drift_profile_with_runtime<'py>(
356        &mut self,
357        py: Python<'py>,
358        config: LLMDriftConfig,
359        metrics: Vec<LLMDriftMetric>,
360        workflow: Option<Bound<'py, PyAny>>,
361        runtime: Arc<tokio::runtime::Runtime>,
362    ) -> Result<Bound<'py, PyAny>, DriftError> {
363        let profile = LLMDriftProfile::new_with_runtime(config, metrics, workflow, runtime)?;
364        Ok(profile.into_bound_py_any(py)?)
365    }
366}