Skip to main content

scouter_types/
drift.rs

1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::genai::profile::GenAIEvalProfile;
4use crate::psi::PsiDriftProfile;
5use crate::spc::SpcDriftProfile;
6use crate::util::ProfileBaseArgs;
7use crate::{AlertDispatchConfig, ProfileArgs};
8use crate::{FileName, PyHelperFuncs};
9use pyo3::prelude::*;
10use pyo3::IntoPyObjectExt;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::fmt::Display;
14use std::ops::Deref;
15use std::path::PathBuf;
16use std::str::FromStr;
17use strum_macros::EnumIter;
/// Identifies which kind of drift profile is in use.
///
/// Exposed to Python as a class that supports equality comparison
/// (`#[pyclass(eq)]`). Each variant corresponds to the same-named variant of
/// `DriftProfile`.
#[pyclass(eq)]
#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
pub enum DriftType {
    /// Value produced by `DriftType::default()`.
    #[default]
    Spc,
    Psi,
    Custom,
    GenAI,
}
27
28#[pymethods]
29impl DriftType {
30    #[staticmethod]
31    pub fn from_value(value: &str) -> Option<DriftType> {
32        match value.to_lowercase().as_str() {
33            "spc" => Some(DriftType::Spc),
34            "psi" => Some(DriftType::Psi),
35            "custom" => Some(DriftType::Custom),
36            "genai" => Some(DriftType::GenAI),
37            _ => None,
38        }
39    }
40
41    #[getter]
42    pub fn to_string(&self) -> &str {
43        match self {
44            DriftType::Spc => "Spc",
45            DriftType::Psi => "Psi",
46            DriftType::Custom => "Custom",
47            DriftType::GenAI => "GenAI",
48        }
49    }
50}
51
52impl Deref for DriftType {
53    type Target = str;
54
55    fn deref(&self) -> &Self::Target {
56        match self {
57            DriftType::Spc => "spc",
58            DriftType::Psi => "psi",
59            DriftType::Custom => "custom",
60            DriftType::GenAI => "genai",
61        }
62    }
63}
64
65impl FromStr for DriftType {
66    type Err = ProfileError;
67
68    fn from_str(value: &str) -> Result<Self, Self::Err> {
69        match value.to_lowercase().as_str() {
70            "spc" => Ok(DriftType::Spc),
71            "psi" => Ok(DriftType::Psi),
72            "custom" => Ok(DriftType::Custom),
73            "genai" => Ok(DriftType::GenAI),
74            _ => Err(ProfileError::InvalidDriftTypeError),
75        }
76    }
77}
78
79impl Display for DriftType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        match self {
82            DriftType::Spc => write!(f, "Spc"),
83            DriftType::Psi => write!(f, "Psi"),
84            DriftType::Custom => write!(f, "Custom"),
85            DriftType::GenAI => write!(f, "LLM"),
86        }
87    }
88}
89
/// Plain bundle of identifying fields plus an alert dispatch configuration.
///
/// NOTE(review): nothing in this file constructs or consumes this struct, so
/// the field meanings below are taken from their names and types only.
pub struct DriftArgs {
    /// Name component (presumably the profile/model name; confirm with callers).
    pub name: String,
    /// Space component (presumably the namespace the name lives in; confirm with callers).
    pub space: String,
    /// Version string.
    pub version: String,
    /// Alert dispatch configuration.
    pub dispatch_config: AlertDispatchConfig,
}
96
/// A drift profile of any supported kind, wrapping the concrete profile type.
///
/// Exposed to Python via `#[pyclass]`. Each variant pairs with the same-named
/// `DriftType` variant (see `DriftProfile::drift_type`).
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DriftProfile {
    /// Wraps an `SpcDriftProfile`.
    Spc(SpcDriftProfile),
    /// Wraps a `PsiDriftProfile`.
    Psi(PsiDriftProfile),
    /// Wraps a `CustomDriftProfile`.
    Custom(CustomDriftProfile),
    /// Wraps a `GenAIEvalProfile`.
    GenAI(GenAIEvalProfile),
}
105
106#[pymethods]
107impl DriftProfile {
108    #[getter]
109    pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
110        match self {
111            DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
112            DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
113            DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
114            DriftProfile::GenAI(profile) => Ok(profile.clone().into_bound_py_any(py)?),
115        }
116    }
117}
118
119impl DriftProfile {
120    /// Create a new DriftProfile from a DriftType and a profile string
121    /// This function will map the drift type to the correct profile type to load
122    ///
123    /// # Arguments
124    ///
125    /// * `drift_type` - DriftType enum
126    /// * `profile` - Profile string
127    ///
128    /// # Returns
129    ///
130    /// * `Result<Self>` - Result of DriftProfile
131    pub fn from_str(drift_type: &DriftType, profile: &str) -> Result<Self, ProfileError> {
132        match drift_type {
133            DriftType::Spc => {
134                let profile = serde_json::from_str(profile)?;
135                Ok(DriftProfile::Spc(profile))
136            }
137            DriftType::Psi => {
138                let profile = serde_json::from_str(profile)?;
139                Ok(DriftProfile::Psi(profile))
140            }
141            DriftType::Custom => {
142                let profile = serde_json::from_str(profile)?;
143                Ok(DriftProfile::Custom(profile))
144            }
145            DriftType::GenAI => {
146                let profile = serde_json::from_str(profile)?;
147                Ok(DriftProfile::GenAI(profile))
148            }
149        }
150    }
151
152    /// Get the base arguments for a drift profile
153    pub fn get_base_args(&self) -> ProfileArgs {
154        match self {
155            DriftProfile::Spc(profile) => profile.get_base_args(),
156            DriftProfile::Psi(profile) => profile.get_base_args(),
157            DriftProfile::Custom(profile) => profile.get_base_args(),
158            DriftProfile::GenAI(profile) => profile.get_base_args(),
159        }
160    }
161
162    pub fn to_value(&self) -> serde_json::Value {
163        match self {
164            DriftProfile::Spc(profile) => profile.to_value(),
165            DriftProfile::Psi(profile) => profile.to_value(),
166            DriftProfile::Custom(profile) => profile.to_value(),
167            DriftProfile::GenAI(profile) => profile.to_value(),
168        }
169    }
170
171    /// Create a new DriftProfile from a value (this is used by scouter-server)
172    /// This function will map the drift type to the correct profile type to load
173    ///
174    /// # Arguments
175    ///
176    /// * `body` - Request body
177    /// * `drift_type` - Drift type string
178    ///
179    pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
180        let drift_type = body["config"]["drift_type"].as_str().unwrap();
181
182        let drift_type = DriftType::from_str(drift_type)?;
183
184        match drift_type {
185            DriftType::Spc => {
186                let profile = serde_json::from_value(body)?;
187                Ok(DriftProfile::Spc(profile))
188            }
189            DriftType::Psi => {
190                let profile = serde_json::from_value(body)?;
191                Ok(DriftProfile::Psi(profile))
192            }
193            DriftType::Custom => {
194                let profile = serde_json::from_value(body)?;
195                Ok(DriftProfile::Custom(profile))
196            }
197            DriftType::GenAI => {
198                let profile = serde_json::from_value(body)?;
199                Ok(DriftProfile::GenAI(profile))
200            }
201        }
202    }
203
204    pub fn from_python(profile: &Bound<'_, PyAny>) -> Result<Self, ProfileError> {
205        let drift_type: DriftType = profile.getattr("drift_type")?.extract::<DriftType>()?;
206        match drift_type {
207            DriftType::Spc => {
208                let profile = profile.extract::<SpcDriftProfile>()?;
209                Ok(DriftProfile::Spc(profile))
210            }
211            DriftType::Psi => {
212                let profile = profile.extract::<PsiDriftProfile>()?;
213                Ok(DriftProfile::Psi(profile))
214            }
215            DriftType::Custom => {
216                let profile = profile.extract::<CustomDriftProfile>()?;
217                Ok(DriftProfile::Custom(profile))
218            }
219            DriftType::GenAI => {
220                let profile = profile.extract::<GenAIEvalProfile>()?;
221                Ok(DriftProfile::GenAI(profile))
222            }
223        }
224    }
225
226    pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
227        match self {
228            DriftProfile::Spc(profile) => Ok(profile),
229            _ => Err(ProfileError::InvalidDriftTypeError),
230        }
231    }
232
233    pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
234        match self {
235            DriftProfile::Psi(profile) => Ok(profile),
236            _ => Err(ProfileError::InvalidDriftTypeError),
237        }
238    }
239
240    pub fn get_genai_profile(&self) -> Result<&GenAIEvalProfile, ProfileError> {
241        match self {
242            DriftProfile::GenAI(profile) => Ok(profile),
243            _ => Err(ProfileError::InvalidDriftTypeError),
244        }
245    }
246
247    pub fn drift_type(&self) -> DriftType {
248        match self {
249            DriftProfile::Spc(_) => DriftType::Spc,
250            DriftProfile::Psi(_) => DriftType::Psi,
251            DriftProfile::Custom(_) => DriftType::Custom,
252            DriftProfile::GenAI(_) => DriftType::GenAI,
253        }
254    }
255
256    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
257        Ok(PyHelperFuncs::save_to_json(
258            self,
259            path,
260            FileName::DriftProfile.to_str(),
261        )?)
262    }
263
264    pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
265        let file = std::fs::read_to_string(&path)?;
266        Ok(serde_json::from_str(&file)?)
267    }
268
269    /// load a profile into the DriftProfile enum from path
270    ///
271    /// # Arguments
272    /// * `path` - Path to the profile
273    ///
274    /// # Returns
275    /// * `Result<Self>` - Result of DriftProfile
276    pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
277        let profile = std::fs::read_to_string(&path)?;
278        let profile_value: Value = serde_json::from_str(&profile).unwrap();
279        DriftProfile::from_value(profile_value)
280    }
281
282    pub fn version(&self) -> Option<String> {
283        match self {
284            DriftProfile::Spc(profile) => Some(profile.config.version.clone()),
285            DriftProfile::Psi(profile) => Some(profile.config.version.clone()),
286            DriftProfile::Custom(profile) => Some(profile.config.version.clone()),
287            DriftProfile::GenAI(profile) => Some(profile.config.version.clone()),
288        }
289    }
290
291    pub fn identifier(&self) -> String {
292        match self {
293            DriftProfile::Spc(profile) => {
294                format!(
295                    "{}/{}/v{}/spc",
296                    profile.config.space, profile.config.name, profile.config.version
297                )
298            }
299            DriftProfile::Psi(profile) => {
300                format!(
301                    "{}/{}/v{}/psi",
302                    profile.config.space, profile.config.name, profile.config.version
303                )
304            }
305            DriftProfile::Custom(profile) => {
306                format!(
307                    "{}/{}/v{}/custom",
308                    profile.config.space, profile.config.name, profile.config.version
309                )
310            }
311            DriftProfile::GenAI(profile) => {
312                format!(
313                    "{}/{}/v{}/genai",
314                    profile.config.space, profile.config.name, profile.config.version
315                )
316            }
317        }
318    }
319
320    pub fn uid(&self) -> &String {
321        match self {
322            DriftProfile::Spc(profile) => &profile.config.uid,
323            DriftProfile::Psi(profile) => &profile.config.uid,
324            DriftProfile::Custom(profile) => &profile.config.uid,
325            DriftProfile::GenAI(profile) => &profile.config.uid,
326        }
327    }
328}
329
330impl Default for DriftProfile {
331    fn default() -> Self {
332        DriftProfile::Spc(SpcDriftProfile::default())
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// `FromStr` accepts every variant name regardless of case and rejects
    /// anything else, including the empty string.
    #[test]
    fn test_drift_type_from_str_base() {
        assert_eq!(DriftType::from_str("SPC").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("PSI").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("CUSTOM").unwrap(), DriftType::Custom);
        assert_eq!(DriftType::from_str("GENAI").unwrap(), DriftType::GenAI);
        assert!(DriftType::from_str("INVALID").is_err());
        assert!(DriftType::from_str("").is_err());
    }

    /// The inherent `to_string` getter returns the variant name as spelled in
    /// the enum definition.
    #[test]
    fn test_drift_type_value_base() {
        assert_eq!(DriftType::Spc.to_string(), "Spc");
        assert_eq!(DriftType::Psi.to_string(), "Psi");
        assert_eq!(DriftType::Custom.to_string(), "Custom");
        assert_eq!(DriftType::GenAI.to_string(), "GenAI");
    }

    /// A profile written with `save_to_json` is read back unchanged by
    /// `load_from_json`.
    #[test]
    fn test_drift_profile_enum() {
        let profile = DriftProfile::Spc(SpcDriftProfile::default());

        // save to temppath
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("profile.json");

        profile.save_to_json(Some(path.clone())).unwrap();

        // assert path exists
        assert!(path.exists());

        // load from path
        let loaded_profile = DriftProfile::load_from_json(path).unwrap();

        // assert profile is the same
        assert_eq!(profile, loaded_profile);
    }
}
375}