// scouter_types/drift.rs

1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::llm::profile::LLMDriftProfile;
4use crate::psi::PsiDriftProfile;
5use crate::spc::SpcDriftProfile;
6use crate::util::ProfileBaseArgs;
7use crate::{AlertDispatchConfig, ProfileArgs};
8use crate::{FileName, PyHelperFuncs};
9use pyo3::prelude::*;
10use pyo3::IntoPyObjectExt;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::fmt::Display;
14use std::path::PathBuf;
15use std::str::FromStr;
16use strum_macros::EnumIter;
17#[pyclass(eq)]
18#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
19pub enum DriftType {
20    #[default]
21    Spc,
22    Psi,
23    Custom,
24    LLM,
25}
26
27#[pymethods]
28impl DriftType {
29    #[staticmethod]
30    pub fn from_value(value: &str) -> Option<DriftType> {
31        match value.to_lowercase().as_str() {
32            "spc" => Some(DriftType::Spc),
33            "psi" => Some(DriftType::Psi),
34            "custom" => Some(DriftType::Custom),
35            "llm" => Some(DriftType::LLM),
36            _ => None,
37        }
38    }
39
40    #[getter]
41    pub fn to_string(&self) -> &str {
42        match self {
43            DriftType::Spc => "Spc",
44            DriftType::Psi => "Psi",
45            DriftType::Custom => "Custom",
46            DriftType::LLM => "LLM",
47        }
48    }
49}
50
51impl FromStr for DriftType {
52    type Err = ProfileError;
53
54    fn from_str(value: &str) -> Result<Self, Self::Err> {
55        match value.to_lowercase().as_str() {
56            "spc" => Ok(DriftType::Spc),
57            "psi" => Ok(DriftType::Psi),
58            "custom" => Ok(DriftType::Custom),
59            "llm" => Ok(DriftType::LLM),
60            _ => Err(ProfileError::InvalidDriftTypeError),
61        }
62    }
63}
64
65impl Display for DriftType {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            DriftType::Spc => write!(f, "Spc"),
69            DriftType::Psi => write!(f, "Psi"),
70            DriftType::Custom => write!(f, "Custom"),
71            DriftType::LLM => write!(f, "LLM"),
72        }
73    }
74}
75
/// Name, space, version and alert dispatch configuration for a drift profile.
///
/// NOTE(review): this struct has no usages in this file; confirm its consumers
/// before relying on any particular meaning of the fields.
pub struct DriftArgs {
    /// Profile name.
    pub name: String,
    /// Space the profile belongs to.
    pub space: String,
    /// Profile version string.
    pub version: String,
    /// Alert dispatch configuration.
    pub dispatch_config: AlertDispatchConfig,
}
82
83#[pyclass]
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub enum DriftProfile {
86    Spc(SpcDriftProfile),
87    Psi(PsiDriftProfile),
88    Custom(CustomDriftProfile),
89    LLM(LLMDriftProfile),
90}
91
92#[pymethods]
93impl DriftProfile {
94    #[getter]
95    pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
96        match self {
97            DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
98            DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
99            DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
100            DriftProfile::LLM(profile) => Ok(profile.clone().into_bound_py_any(py)?),
101        }
102    }
103}
104
105impl DriftProfile {
106    /// Create a new DriftProfile from a DriftType and a profile string
107    /// This function will map the drift type to the correct profile type to load
108    ///
109    /// # Arguments
110    ///
111    /// * `drift_type` - DriftType enum
112    /// * `profile` - Profile string
113    ///
114    /// # Returns
115    ///
116    /// * `Result<Self>` - Result of DriftProfile
117    pub fn from_str(drift_type: DriftType, profile: String) -> Result<Self, ProfileError> {
118        match drift_type {
119            DriftType::Spc => {
120                let profile = serde_json::from_str(&profile)?;
121                Ok(DriftProfile::Spc(profile))
122            }
123            DriftType::Psi => {
124                let profile = serde_json::from_str(&profile)?;
125                Ok(DriftProfile::Psi(profile))
126            }
127            DriftType::Custom => {
128                let profile = serde_json::from_str(&profile)?;
129                Ok(DriftProfile::Custom(profile))
130            }
131            DriftType::LLM => {
132                let profile = serde_json::from_str(&profile)?;
133                Ok(DriftProfile::LLM(profile))
134            }
135        }
136    }
137
138    /// Get the base arguments for a drift profile
139    pub fn get_base_args(&self) -> ProfileArgs {
140        match self {
141            DriftProfile::Spc(profile) => profile.get_base_args(),
142            DriftProfile::Psi(profile) => profile.get_base_args(),
143            DriftProfile::Custom(profile) => profile.get_base_args(),
144            DriftProfile::LLM(profile) => profile.get_base_args(),
145        }
146    }
147
148    pub fn to_value(&self) -> serde_json::Value {
149        match self {
150            DriftProfile::Spc(profile) => profile.to_value(),
151            DriftProfile::Psi(profile) => profile.to_value(),
152            DriftProfile::Custom(profile) => profile.to_value(),
153            DriftProfile::LLM(profile) => profile.to_value(),
154        }
155    }
156
157    /// Create a new DriftProfile from a value (this is used by scouter-server)
158    /// This function will map the drift type to the correct profile type to load
159    ///
160    /// # Arguments
161    ///
162    /// * `body` - Request body
163    /// * `drift_type` - Drift type string
164    ///
165    pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
166        let drift_type = body["config"]["drift_type"].as_str().unwrap();
167
168        let drift_type = DriftType::from_str(drift_type)?;
169
170        match drift_type {
171            DriftType::Spc => {
172                let profile = serde_json::from_value(body)?;
173                Ok(DriftProfile::Spc(profile))
174            }
175            DriftType::Psi => {
176                let profile = serde_json::from_value(body)?;
177                Ok(DriftProfile::Psi(profile))
178            }
179            DriftType::Custom => {
180                let profile = serde_json::from_value(body)?;
181                Ok(DriftProfile::Custom(profile))
182            }
183            DriftType::LLM => {
184                let profile = serde_json::from_value(body)?;
185                Ok(DriftProfile::LLM(profile))
186            }
187        }
188    }
189
190    pub fn from_python(
191        drift_type: DriftType,
192        profile: &Bound<'_, PyAny>,
193    ) -> Result<Self, ProfileError> {
194        match drift_type {
195            DriftType::Spc => {
196                let profile = profile.extract::<SpcDriftProfile>()?;
197                Ok(DriftProfile::Spc(profile))
198            }
199            DriftType::Psi => {
200                let profile = profile.extract::<PsiDriftProfile>()?;
201                Ok(DriftProfile::Psi(profile))
202            }
203            DriftType::Custom => {
204                let profile = profile.extract::<CustomDriftProfile>()?;
205                Ok(DriftProfile::Custom(profile))
206            }
207            DriftType::LLM => {
208                let profile = profile.extract::<LLMDriftProfile>()?;
209                Ok(DriftProfile::LLM(profile))
210            }
211        }
212    }
213
214    pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
215        match self {
216            DriftProfile::Spc(profile) => Ok(profile),
217            _ => Err(ProfileError::InvalidDriftTypeError),
218        }
219    }
220
221    pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
222        match self {
223            DriftProfile::Psi(profile) => Ok(profile),
224            _ => Err(ProfileError::InvalidDriftTypeError),
225        }
226    }
227
228    pub fn get_llm_profile(&self) -> Result<&LLMDriftProfile, ProfileError> {
229        match self {
230            DriftProfile::LLM(profile) => Ok(profile),
231            _ => Err(ProfileError::InvalidDriftTypeError),
232        }
233    }
234
235    pub fn drift_type(&self) -> DriftType {
236        match self {
237            DriftProfile::Spc(_) => DriftType::Spc,
238            DriftProfile::Psi(_) => DriftType::Psi,
239            DriftProfile::Custom(_) => DriftType::Custom,
240            DriftProfile::LLM(_) => DriftType::LLM,
241        }
242    }
243
244    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
245        Ok(PyHelperFuncs::save_to_json(
246            self,
247            path,
248            FileName::DriftProfile.to_str(),
249        )?)
250    }
251
252    pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
253        let file = std::fs::read_to_string(&path)?;
254        Ok(serde_json::from_str(&file)?)
255    }
256
257    /// load a profile into the DriftProfile enum from path
258    ///
259    /// # Arguments
260    /// * `path` - Path to the profile
261    ///
262    /// # Returns
263    /// * `Result<Self>` - Result of DriftProfile
264    pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
265        let profile = std::fs::read_to_string(&path)?;
266        let profile_value: Value = serde_json::from_str(&profile).unwrap();
267        DriftProfile::from_value(profile_value)
268    }
269
270    pub fn version(&self) -> Option<String> {
271        match self {
272            DriftProfile::Spc(profile) => Some(profile.config.version.clone()),
273            DriftProfile::Psi(profile) => Some(profile.config.version.clone()),
274            DriftProfile::Custom(profile) => Some(profile.config.version.clone()),
275            DriftProfile::LLM(profile) => Some(profile.config.version.clone()),
276        }
277    }
278
279    pub fn identifier(&self) -> String {
280        match self {
281            DriftProfile::Spc(profile) => {
282                format!(
283                    "{}/{}/v{}/spc",
284                    profile.config.space, profile.config.name, profile.config.version
285                )
286            }
287            DriftProfile::Psi(profile) => {
288                format!(
289                    "{}/{}/v{}/psi",
290                    profile.config.space, profile.config.name, profile.config.version
291                )
292            }
293            DriftProfile::Custom(profile) => {
294                format!(
295                    "{}/{}/v{}/custom",
296                    profile.config.space, profile.config.name, profile.config.version
297                )
298            }
299            DriftProfile::LLM(profile) => {
300                format!(
301                    "{}/{}/v{}/llm",
302                    profile.config.space, profile.config.name, profile.config.version
303                )
304            }
305        }
306    }
307}
308
309impl Default for DriftProfile {
310    fn default() -> Self {
311        DriftProfile::Spc(SpcDriftProfile::default())
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Every `DriftType` variant, so looped checks cover all of them.
    fn all_drift_types() -> [DriftType; 4] {
        [DriftType::Spc, DriftType::Psi, DriftType::Custom, DriftType::LLM]
    }

    #[test]
    fn test_drift_type_from_str_base() {
        assert_eq!(DriftType::from_str("SPC").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("PSI").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("CUSTOM").unwrap(), DriftType::Custom);
        assert_eq!(DriftType::from_str("LLM").unwrap(), DriftType::LLM);
        // Parsing is case-insensitive.
        assert_eq!(DriftType::from_str("llm").unwrap(), DriftType::LLM);
        assert!(DriftType::from_str("INVALID").is_err());
    }

    #[test]
    fn test_drift_type_from_value_base() {
        assert_eq!(DriftType::from_value("Custom"), Some(DriftType::Custom));
        assert_eq!(DriftType::from_value("llm"), Some(DriftType::LLM));
        assert_eq!(DriftType::from_value("INVALID"), None);
    }

    #[test]
    fn test_drift_type_value_base() {
        assert_eq!(DriftType::Spc.to_string(), "Spc");
        assert_eq!(DriftType::Psi.to_string(), "Psi");
        assert_eq!(DriftType::Custom.to_string(), "Custom");
        assert_eq!(DriftType::LLM.to_string(), "LLM");
    }

    #[test]
    fn test_drift_type_display_matches_getter_and_round_trips() {
        for drift_type in all_drift_types() {
            assert_eq!(format!("{}", drift_type), drift_type.to_string());
            assert_eq!(
                DriftType::from_str(drift_type.to_string()).unwrap(),
                drift_type
            );
        }
    }

    #[test]
    fn test_drift_profile_variant_accessors() {
        let profile = DriftProfile::default();

        assert_eq!(profile.drift_type(), DriftType::Spc);
        assert!(profile.get_spc_profile().is_ok());
        assert!(profile.get_psi_profile().is_err());
        assert!(profile.get_llm_profile().is_err());
        assert!(profile.identifier().ends_with("/spc"));
    }

    #[test]
    fn test_drift_profile_enum() {
        let profile = DriftProfile::Spc(SpcDriftProfile::default());

        // save to temppath
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("profile.json");

        profile.save_to_json(Some(path.clone())).unwrap();

        // assert path exists
        assert!(path.exists());

        // load from path
        let loaded_profile = DriftProfile::load_from_json(path).unwrap();

        // assert profile is the same
        assert_eq!(profile, loaded_profile);
    }
}