// scouter_types/drift.rs

1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::psi::PsiDriftProfile;
4use crate::spc::SpcDriftProfile;
5use crate::util::ProfileBaseArgs;
6use crate::{AlertDispatchConfig, ProfileArgs};
7use crate::{FileName, ProfileFuncs};
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::fmt::Display;
13use std::path::PathBuf;
14use std::str::FromStr;
15use strum_macros::EnumIter;
16#[pyclass(eq)]
17#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
18pub enum DriftType {
19    #[default]
20    Spc,
21    Psi,
22    Custom,
23}
24
25#[pymethods]
26impl DriftType {
27    #[staticmethod]
28    pub fn from_value(value: &str) -> Option<DriftType> {
29        match value.to_lowercase().as_str() {
30            "spc" => Some(DriftType::Spc),
31            "psi" => Some(DriftType::Psi),
32            "custom" => Some(DriftType::Custom),
33            _ => None,
34        }
35    }
36
37    #[getter]
38    pub fn to_string(&self) -> &str {
39        match self {
40            DriftType::Spc => "Spc",
41            DriftType::Psi => "Psi",
42            DriftType::Custom => "Custom",
43        }
44    }
45}
46
47impl FromStr for DriftType {
48    type Err = ProfileError;
49
50    fn from_str(value: &str) -> Result<Self, Self::Err> {
51        match value.to_lowercase().as_str() {
52            "spc" => Ok(DriftType::Spc),
53            "psi" => Ok(DriftType::Psi),
54            "custom" => Ok(DriftType::Custom),
55            _ => Err(ProfileError::InvalidDriftTypeError),
56        }
57    }
58}
59
60impl Display for DriftType {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match self {
63            DriftType::Spc => write!(f, "Spc"),
64            DriftType::Psi => write!(f, "Psi"),
65            DriftType::Custom => write!(f, "Custom"),
66        }
67    }
68}
69
70pub struct DriftArgs {
71    pub name: String,
72    pub space: String,
73    pub version: String,
74    pub dispatch_config: AlertDispatchConfig,
75}
76
77// Generic enum to be used on scouter server
78#[pyclass]
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub enum DriftProfile {
81    Spc(SpcDriftProfile),
82    Psi(PsiDriftProfile),
83    Custom(CustomDriftProfile),
84}
85
86#[pymethods]
87impl DriftProfile {
88    #[getter]
89    pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
90        match self {
91            DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
92            DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
93            DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
94        }
95    }
96}
97
98impl DriftProfile {
99    /// Create a new DriftProfile from a DriftType and a profile string
100    /// This function will map the drift type to the correct profile type to load
101    ///
102    /// # Arguments
103    ///
104    /// * `drift_type` - DriftType enum
105    /// * `profile` - Profile string
106    ///
107    /// # Returns
108    ///
109    /// * `Result<Self>` - Result of DriftProfile
110    pub fn from_str(drift_type: DriftType, profile: String) -> Result<Self, ProfileError> {
111        match drift_type {
112            DriftType::Spc => {
113                let profile = serde_json::from_str(&profile)?;
114                Ok(DriftProfile::Spc(profile))
115            }
116            DriftType::Psi => {
117                let profile = serde_json::from_str(&profile)?;
118                Ok(DriftProfile::Psi(profile))
119            }
120            DriftType::Custom => {
121                let profile = serde_json::from_str(&profile)?;
122                Ok(DriftProfile::Custom(profile))
123            }
124        }
125    }
126
127    /// Get the base arguments for a drift profile
128    pub fn get_base_args(&self) -> ProfileArgs {
129        match self {
130            DriftProfile::Spc(profile) => profile.get_base_args(),
131            DriftProfile::Psi(profile) => profile.get_base_args(),
132            DriftProfile::Custom(profile) => profile.get_base_args(),
133        }
134    }
135
136    pub fn to_value(&self) -> serde_json::Value {
137        match self {
138            DriftProfile::Spc(profile) => profile.to_value(),
139            DriftProfile::Psi(profile) => profile.to_value(),
140            DriftProfile::Custom(profile) => profile.to_value(),
141        }
142    }
143
144    /// Create a new DriftProfile from a value (this is used by scouter-server)
145    /// This function will map the drift type to the correct profile type to load
146    ///
147    /// # Arguments
148    ///
149    /// * `body` - Request body
150    /// * `drift_type` - Drift type string
151    ///
152    pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
153        let drift_type = body["config"]["drift_type"].as_str().unwrap();
154
155        let drift_type = DriftType::from_str(drift_type)?;
156
157        match drift_type {
158            DriftType::Spc => {
159                let profile = serde_json::from_value(body)?;
160                Ok(DriftProfile::Spc(profile))
161            }
162            DriftType::Psi => {
163                let profile = serde_json::from_value(body)?;
164                Ok(DriftProfile::Psi(profile))
165            }
166            DriftType::Custom => {
167                let profile = serde_json::from_value(body)?;
168                Ok(DriftProfile::Custom(profile))
169            }
170        }
171    }
172
173    pub fn from_python(
174        drift_type: DriftType,
175        profile: &Bound<'_, PyAny>,
176    ) -> Result<Self, ProfileError> {
177        match drift_type {
178            DriftType::Spc => {
179                let profile = profile.extract::<SpcDriftProfile>()?;
180                Ok(DriftProfile::Spc(profile))
181            }
182            DriftType::Psi => {
183                let profile = profile.extract::<PsiDriftProfile>()?;
184                Ok(DriftProfile::Psi(profile))
185            }
186            DriftType::Custom => {
187                let profile = profile.extract::<CustomDriftProfile>()?;
188                Ok(DriftProfile::Custom(profile))
189            }
190        }
191    }
192
193    pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
194        match self {
195            DriftProfile::Spc(profile) => Ok(profile),
196            _ => Err(ProfileError::InvalidDriftTypeError),
197        }
198    }
199
200    pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
201        match self {
202            DriftProfile::Psi(profile) => Ok(profile),
203            _ => Err(ProfileError::InvalidDriftTypeError),
204        }
205    }
206
207    pub fn drift_type(&self) -> DriftType {
208        match self {
209            DriftProfile::Spc(_) => DriftType::Spc,
210            DriftProfile::Psi(_) => DriftType::Psi,
211            DriftProfile::Custom(_) => DriftType::Custom,
212        }
213    }
214
215    pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
216        Ok(ProfileFuncs::save_to_json(
217            self,
218            path,
219            FileName::DriftProfile.to_str(),
220        )?)
221    }
222
223    pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
224        let file = std::fs::read_to_string(&path)?;
225        Ok(serde_json::from_str(&file)?)
226    }
227
228    /// load a profile into the DriftProfile enum from path
229    ///
230    /// # Arguments
231    /// * `path` - Path to the profile
232    ///
233    /// # Returns
234    /// * `Result<Self>` - Result of DriftProfile
235    pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
236        let profile = std::fs::read_to_string(&path)?;
237        let profile_value: Value = serde_json::from_str(&profile).unwrap();
238        DriftProfile::from_value(profile_value)
239    }
240}
241
242impl Default for DriftProfile {
243    fn default() -> Self {
244        DriftProfile::Spc(SpcDriftProfile::default())
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Upper-case names parse to their variants; unknown names are rejected.
    #[test]
    fn test_drift_type_from_str_base() {
        let cases = [
            ("SPC", DriftType::Spc),
            ("PSI", DriftType::Psi),
            ("CUSTOM", DriftType::Custom),
        ];
        for (raw, expected) in cases {
            assert_eq!(DriftType::from_str(raw).unwrap(), expected);
        }
        assert!(DriftType::from_str("INVALID").is_err());
    }

    /// The `to_string` getter yields the canonical variant names.
    #[test]
    fn test_drift_type_value_base() {
        let cases = [
            (DriftType::Spc, "Spc"),
            (DriftType::Psi, "Psi"),
            (DriftType::Custom, "Custom"),
        ];
        for (kind, name) in cases {
            assert_eq!(kind.to_string(), name);
        }
    }

    /// A profile survives a save_to_json / load_from_json round trip.
    #[test]
    fn test_drift_profile_enum() {
        let original = DriftProfile::Spc(SpcDriftProfile::default());

        let scratch = TempDir::new().unwrap();
        let json_path = scratch.path().join("profile.json");

        original.save_to_json(Some(json_path.clone())).unwrap();
        assert!(json_path.exists());

        let round_tripped = DriftProfile::load_from_json(json_path).unwrap();
        assert_eq!(original, round_tripped);
    }
}