1use crate::data_utils::DataConverterEnum;
2use crate::drifter::{custom::CustomDrifter, llm::LLMDrifter, psi::PsiDrifter, spc::SpcDrifter};
3use pyo3::prelude::*;
4use pyo3::types::PyList;
5use pyo3::IntoPyObjectExt;
6use scouter_drift::error::DriftError;
7use scouter_drift::spc::SpcDriftMap;
8use scouter_types::llm::{LLMDriftMap, LLMMetric};
9use scouter_types::spc::SpcDriftProfile;
10use scouter_types::LLMRecord;
11use scouter_types::{
12 custom::{CustomDriftProfile, CustomMetric, CustomMetricDriftConfig},
13 llm::{LLMDriftConfig, LLMDriftProfile},
14 psi::{PsiDriftConfig, PsiDriftMap, PsiDriftProfile},
15 spc::SpcDriftConfig,
16 DataType, DriftProfile, DriftType,
17};
18use std::fmt::Debug;
19use std::sync::Arc;
20use std::sync::RwLock;
21pub enum DriftMap {
22 Spc(SpcDriftMap),
23 Psi(PsiDriftMap),
24 LLM(LLMDriftMap),
25}
26
27pub enum DriftConfig {
28 Spc(Arc<RwLock<SpcDriftConfig>>),
29 Psi(Arc<RwLock<PsiDriftConfig>>),
30 LLM(LLMDriftConfig),
31 Custom(CustomMetricDriftConfig),
32}
33
34impl DriftConfig {
35 pub fn spc_config(&self) -> Result<Arc<RwLock<SpcDriftConfig>>, DriftError> {
36 match self {
37 DriftConfig::Spc(cfg) => Ok(cfg.clone()),
38 _ => Err(DriftError::InvalidConfigError),
39 }
40 }
41
42 pub fn psi_config(&self) -> Result<Arc<RwLock<PsiDriftConfig>>, DriftError> {
43 match self {
44 DriftConfig::Psi(cfg) => Ok(cfg.clone()),
45 _ => Err(DriftError::InvalidConfigError),
46 }
47 }
48
49 pub fn custom_config(&self) -> Result<CustomMetricDriftConfig, DriftError> {
50 match self {
51 DriftConfig::Custom(cfg) => Ok(cfg.clone()),
52 _ => Err(DriftError::InvalidConfigError),
53 }
54 }
55
56 pub fn llm_config(&self) -> Result<LLMDriftConfig, DriftError> {
57 match self {
58 DriftConfig::LLM(cfg) => Ok(cfg.clone()),
59 _ => Err(DriftError::InvalidConfigError),
60 }
61 }
62}
63
64pub enum Drifter {
65 Spc(SpcDrifter),
66 Psi(PsiDrifter),
67 Custom(CustomDrifter),
68 LLM(LLMDrifter),
69}
70
71impl Debug for Drifter {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Drifter::Spc(_) => write!(f, "SpcDrifter"),
75 Drifter::Psi(_) => write!(f, "PsiDrifter"),
76 Drifter::Custom(_) => write!(f, "CustomDrifter"),
77 Drifter::LLM(_) => write!(f, "LLMDrifter"),
78 }
79 }
80}
81
82impl Drifter {
83 fn from_drift_type(drift_type: DriftType) -> Result<Self, DriftError> {
84 match drift_type {
85 DriftType::Spc => Ok(Drifter::Spc(SpcDrifter::new())),
86 DriftType::Psi => Ok(Drifter::Psi(PsiDrifter::new())),
87 DriftType::Custom => Ok(Drifter::Custom(CustomDrifter::new())),
88 DriftType::LLM => Ok(Drifter::LLM(LLMDrifter::new())),
89 }
90 }
91
92 fn create_drift_profile<'py>(
93 &mut self,
94 py: Python<'py>,
95 data: &Bound<'py, PyAny>,
96 data_type: &DataType,
97 config: DriftConfig,
98 workflow: Option<Bound<'py, PyAny>>,
99 ) -> Result<DriftProfile, DriftError> {
100 match self {
101 Drifter::Spc(drifter) => {
104 let data = DataConverterEnum::convert_data(py, data_type, data)?;
105 let profile = drifter.create_drift_profile(data, config.spc_config()?)?;
106 Ok(DriftProfile::Spc(profile))
107 }
108 Drifter::Psi(drifter) => {
109 let data = DataConverterEnum::convert_data(py, data_type, data)?;
110 let profile = drifter.create_drift_profile(data, config.psi_config()?)?;
111 Ok(DriftProfile::Psi(profile))
112 }
113 Drifter::Custom(drifter) => {
114 let data = if data.is_instance_of::<PyList>() {
117 data.extract::<Vec<CustomMetric>>()?
118 } else {
119 let metric = data.extract::<CustomMetric>()?;
120 vec![metric]
121 };
122
123 let profile = drifter.create_drift_profile(config.custom_config()?, data)?;
124 Ok(DriftProfile::Custom(profile))
125 }
126 Drifter::LLM(drifter) => {
127 let metrics = if data.is_instance_of::<PyList>() {
129 data.extract::<Vec<LLMMetric>>()?
130 } else {
131 let metric = data.extract::<LLMMetric>()?;
132 vec![metric]
133 };
134 let profile =
135 drifter.create_drift_profile(config.llm_config()?, metrics, workflow)?;
136 Ok(DriftProfile::LLM(profile))
137 }
138 }
139 }
140
141 fn compute_drift<'py>(
142 &mut self,
143 py: Python<'py>,
144 data: &Bound<'py, PyAny>,
145 data_type: &DataType,
146 profile: &DriftProfile,
147 ) -> Result<DriftMap, DriftError> {
148 match self {
149 Drifter::Spc(drifter) => {
150 let data = DataConverterEnum::convert_data(py, data_type, data)?;
151 let drift_profile = profile.get_spc_profile()?;
152 let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
153 Ok(DriftMap::Spc(drift_map))
154 }
155 Drifter::Psi(drifter) => {
156 let data = DataConverterEnum::convert_data(py, data_type, data)?;
157 let drift_profile = profile.get_psi_profile()?;
158 let drift_map = drifter.compute_drift(data, drift_profile.clone())?;
159 Ok(DriftMap::Psi(drift_map))
160 }
161 Drifter::Custom(_) => {
162 Err(DriftError::NotImplemented)
164 }
165
166 Drifter::LLM(drifter) => {
167 let data = if data.is_instance_of::<PyList>() {
169 data.extract::<Vec<LLMRecord>>()?
170 } else {
171 let metric = data.extract::<LLMRecord>()?;
172 vec![metric]
173 };
174 let records = drifter.compute_drift(data, profile.get_llm_profile()?)?;
175
176 Ok(DriftMap::LLM(LLMDriftMap { records }))
177 }
178 }
179 }
180}
181
182#[pyclass(name = "Drifter")]
183#[derive(Debug, Default)]
184pub struct PyDrifter {}
185
186#[pymethods]
187impl PyDrifter {
188 #[new]
189 pub fn new() -> Self {
190 Self {}
191 }
192
193 #[pyo3(signature = (data, config=None, data_type=None, workflow=None))]
203 pub fn create_drift_profile<'py>(
204 &self,
205 py: Python<'py>,
206 data: &Bound<'py, PyAny>,
207 config: Option<&Bound<'py, PyAny>>,
208 data_type: Option<&DataType>,
209 workflow: Option<Bound<'py, PyAny>>,
210 ) -> Result<Bound<'py, PyAny>, DriftError> {
211 let (config_helper, drift_type) = if let Some(obj) = config {
214 let drift_type = obj.getattr("drift_type")?.extract::<DriftType>()?;
215 let drift_config = match drift_type {
216 DriftType::Spc => {
217 let config = obj.extract::<SpcDriftConfig>()?;
218 DriftConfig::Spc(Arc::new(config.into()))
219 }
220 DriftType::Psi => {
221 let config = obj.extract::<PsiDriftConfig>()?;
222 DriftConfig::Psi(Arc::new(config.into()))
223 }
224 DriftType::Custom => {
225 let config = obj.extract::<CustomMetricDriftConfig>()?;
226 DriftConfig::Custom(config)
227 }
228 DriftType::LLM => {
229 let config = obj.extract::<LLMDriftConfig>()?;
230 DriftConfig::LLM(config)
231 }
232 };
233 (drift_config, drift_type)
234 } else {
235 (
236 DriftConfig::Spc(Arc::new(SpcDriftConfig::default().into())),
237 DriftType::Spc,
238 )
239 };
240
241 let mut drift_helper = Drifter::from_drift_type(drift_type)?;
242
243 let data_type = match data_type {
246 Some(data_type) => data_type,
247 None => {
248 let class = data.getattr("__class__")?;
249 let module = class.getattr("__module__")?.str()?.to_string();
250 let name = class.getattr("__name__")?.str()?.to_string();
251 let full_class_name = format!("{module}.{name}");
252
253 &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
254 }
256 };
257
258 let profile =
259 drift_helper.create_drift_profile(py, data, data_type, config_helper, workflow)?;
260
261 match profile {
262 DriftProfile::Spc(profile) => Ok(profile.into_bound_py_any(py)?),
263 DriftProfile::Psi(profile) => Ok(profile.into_bound_py_any(py)?),
264 DriftProfile::Custom(profile) => Ok(profile.into_bound_py_any(py)?),
265 DriftProfile::LLM(profile) => Ok(profile.into_bound_py_any(py)?),
266 }
267 }
268
269 #[pyo3(signature = (config, metrics, workflow=None))]
272 pub fn create_llm_drift_profile<'py>(
273 &mut self,
274 py: Python<'py>,
275 config: LLMDriftConfig,
276 metrics: Vec<LLMMetric>,
277 workflow: Option<Bound<'py, PyAny>>,
278 ) -> Result<Bound<'py, PyAny>, DriftError> {
279 let profile = LLMDriftProfile::new(config, metrics, workflow)?;
280 Ok(profile.into_bound_py_any(py)?)
281 }
282
283 #[pyo3(signature = (data, drift_profile, data_type=None))]
284 pub fn compute_drift<'py>(
285 &self,
286 py: Python<'py>,
287 data: &Bound<'py, PyAny>,
288 drift_profile: &Bound<'py, PyAny>,
289 data_type: Option<&DataType>,
290 ) -> Result<Bound<'py, PyAny>, DriftError> {
291 let drift_type = drift_profile
292 .getattr("config")?
293 .getattr("drift_type")?
294 .extract::<DriftType>()?;
295
296 let profile = match drift_type {
297 DriftType::Spc => {
298 let profile = drift_profile.extract::<SpcDriftProfile>()?;
299 DriftProfile::Spc(profile)
300 }
301 DriftType::Psi => {
302 let profile = drift_profile.extract::<PsiDriftProfile>()?;
303 DriftProfile::Psi(profile)
304 }
305 DriftType::Custom => {
306 let profile = drift_profile.extract::<CustomDriftProfile>()?;
307 DriftProfile::Custom(profile)
308 }
309 DriftType::LLM => {
310 let profile = drift_profile.extract::<LLMDriftProfile>()?;
311 DriftProfile::LLM(profile)
312 }
313 };
314
315 let data_type = match data_type {
320 Some(data_type) => data_type,
321 None => {
322 if drift_type == DriftType::LLM {
323 &DataType::LLM
325 } else {
326 let class = data.getattr("__class__")?;
327 let module = class.getattr("__module__")?.str()?.to_string();
328 let name = class.getattr("__name__")?.str()?.to_string();
329 let full_class_name = format!("{module}.{name}");
330
331 &DataType::from_module_name(&full_class_name).unwrap_or(DataType::Unknown)
333 }
334 }
335 };
336
337 let mut drift_helper = Drifter::from_drift_type(drift_type)?;
338
339 let drift_map = drift_helper.compute_drift(py, data, data_type, &profile)?;
340
341 match drift_map {
342 DriftMap::Spc(map) => Ok(map.into_bound_py_any(py)?),
343 DriftMap::Psi(map) => Ok(map.into_bound_py_any(py)?),
344 DriftMap::LLM(map) => Ok(map.into_bound_py_any(py)?),
345 }
346 }
347}