1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::psi::PsiDriftProfile;
4use crate::spc::SpcDriftProfile;
5use crate::util::ProfileBaseArgs;
6use crate::{AlertDispatchConfig, ProfileArgs};
7use crate::{FileName, ProfileFuncs};
8use pyo3::prelude::*;
9use pyo3::IntoPyObjectExt;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::fmt::Display;
13use std::path::PathBuf;
14use std::str::FromStr;
15use strum_macros::EnumIter;
/// Kind of drift profile; selects which `DriftProfile` variant is used.
///
/// `Spc` is the `Default`. Variant order is significant: it is the iteration order produced by
/// the derived `EnumIter`. Parsing from text is case-insensitive via the `FromStr` impl, and
/// `Display` renders the canonical names `Spc`, `Psi`, and `Custom`.
#[pyclass(eq)]
#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
pub enum DriftType {
    /// Profiles stored as [`SpcDriftProfile`] (`DriftProfile::Spc`).
    #[default]
    Spc,
    /// Profiles stored as [`PsiDriftProfile`] (`DriftProfile::Psi`).
    Psi,
    /// Profiles stored as [`CustomDriftProfile`] (`DriftProfile::Custom`).
    Custom,
}
24
25#[pymethods]
26impl DriftType {
27 #[staticmethod]
28 pub fn from_value(value: &str) -> Option<DriftType> {
29 match value.to_lowercase().as_str() {
30 "spc" => Some(DriftType::Spc),
31 "psi" => Some(DriftType::Psi),
32 "custom" => Some(DriftType::Custom),
33 _ => None,
34 }
35 }
36
37 #[getter]
38 pub fn to_string(&self) -> &str {
39 match self {
40 DriftType::Spc => "Spc",
41 DriftType::Psi => "Psi",
42 DriftType::Custom => "Custom",
43 }
44 }
45}
46
47impl FromStr for DriftType {
48 type Err = ProfileError;
49
50 fn from_str(value: &str) -> Result<Self, Self::Err> {
51 match value.to_lowercase().as_str() {
52 "spc" => Ok(DriftType::Spc),
53 "psi" => Ok(DriftType::Psi),
54 "custom" => Ok(DriftType::Custom),
55 _ => Err(ProfileError::InvalidDriftTypeError),
56 }
57 }
58}
59
60impl Display for DriftType {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 DriftType::Spc => write!(f, "Spc"),
64 DriftType::Psi => write!(f, "Psi"),
65 DriftType::Custom => write!(f, "Custom"),
66 }
67 }
68}
69
/// Plain argument bundle: a name/space/version triple plus an [`AlertDispatchConfig`].
///
/// NOTE(review): not used anywhere else in this file; presumably consumed when building drift
/// configurations — confirm against callers. No validation is performed on any field.
pub struct DriftArgs {
    /// Name component of the identifier (presumably the model/profile name — confirm).
    pub name: String,
    /// Space component of the identifier — TODO confirm exact meaning with callers.
    pub space: String,
    /// Version component of the identifier; stored as an arbitrary string.
    pub version: String,
    /// Alert dispatch configuration carried alongside the identifier.
    pub dispatch_config: AlertDispatchConfig,
}
76
/// A drift profile of any supported kind, wrapping the concrete profile type.
///
/// Serde uses its default (externally tagged) enum representation here, e.g.
/// `{"Spc": { ... }}`; `load_from_json` reads that tagged form. By contrast, `from_str`,
/// `from_value`, and `from_profile_path` take the bare inner profile JSON and choose the
/// variant separately.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DriftProfile {
    /// Wraps an [`SpcDriftProfile`]; corresponds to `DriftType::Spc`.
    Spc(SpcDriftProfile),
    /// Wraps a [`PsiDriftProfile`]; corresponds to `DriftType::Psi`.
    Psi(PsiDriftProfile),
    /// Wraps a [`CustomDriftProfile`]; corresponds to `DriftType::Custom`.
    Custom(CustomDriftProfile),
}
85
86#[pymethods]
87impl DriftProfile {
88 #[getter]
89 pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
90 match self {
91 DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
92 DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
93 DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
94 }
95 }
96}
97
98impl DriftProfile {
99 pub fn from_str(drift_type: DriftType, profile: String) -> Result<Self, ProfileError> {
111 match drift_type {
112 DriftType::Spc => {
113 let profile = serde_json::from_str(&profile)?;
114 Ok(DriftProfile::Spc(profile))
115 }
116 DriftType::Psi => {
117 let profile = serde_json::from_str(&profile)?;
118 Ok(DriftProfile::Psi(profile))
119 }
120 DriftType::Custom => {
121 let profile = serde_json::from_str(&profile)?;
122 Ok(DriftProfile::Custom(profile))
123 }
124 }
125 }
126
127 pub fn get_base_args(&self) -> ProfileArgs {
129 match self {
130 DriftProfile::Spc(profile) => profile.get_base_args(),
131 DriftProfile::Psi(profile) => profile.get_base_args(),
132 DriftProfile::Custom(profile) => profile.get_base_args(),
133 }
134 }
135
136 pub fn to_value(&self) -> serde_json::Value {
137 match self {
138 DriftProfile::Spc(profile) => profile.to_value(),
139 DriftProfile::Psi(profile) => profile.to_value(),
140 DriftProfile::Custom(profile) => profile.to_value(),
141 }
142 }
143
144 pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
153 let drift_type = body["config"]["drift_type"].as_str().unwrap();
154
155 let drift_type = DriftType::from_str(drift_type)?;
156
157 match drift_type {
158 DriftType::Spc => {
159 let profile = serde_json::from_value(body)?;
160 Ok(DriftProfile::Spc(profile))
161 }
162 DriftType::Psi => {
163 let profile = serde_json::from_value(body)?;
164 Ok(DriftProfile::Psi(profile))
165 }
166 DriftType::Custom => {
167 let profile = serde_json::from_value(body)?;
168 Ok(DriftProfile::Custom(profile))
169 }
170 }
171 }
172
173 pub fn from_python(
174 drift_type: DriftType,
175 profile: &Bound<'_, PyAny>,
176 ) -> Result<Self, ProfileError> {
177 match drift_type {
178 DriftType::Spc => {
179 let profile = profile.extract::<SpcDriftProfile>()?;
180 Ok(DriftProfile::Spc(profile))
181 }
182 DriftType::Psi => {
183 let profile = profile.extract::<PsiDriftProfile>()?;
184 Ok(DriftProfile::Psi(profile))
185 }
186 DriftType::Custom => {
187 let profile = profile.extract::<CustomDriftProfile>()?;
188 Ok(DriftProfile::Custom(profile))
189 }
190 }
191 }
192
193 pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
194 match self {
195 DriftProfile::Spc(profile) => Ok(profile),
196 _ => Err(ProfileError::InvalidDriftTypeError),
197 }
198 }
199
200 pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
201 match self {
202 DriftProfile::Psi(profile) => Ok(profile),
203 _ => Err(ProfileError::InvalidDriftTypeError),
204 }
205 }
206
207 pub fn drift_type(&self) -> DriftType {
208 match self {
209 DriftProfile::Spc(_) => DriftType::Spc,
210 DriftProfile::Psi(_) => DriftType::Psi,
211 DriftProfile::Custom(_) => DriftType::Custom,
212 }
213 }
214
215 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
216 Ok(ProfileFuncs::save_to_json(
217 self,
218 path,
219 FileName::DriftProfile.to_str(),
220 )?)
221 }
222
223 pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
224 let file = std::fs::read_to_string(&path)?;
225 Ok(serde_json::from_str(&file)?)
226 }
227
228 pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
236 let profile = std::fs::read_to_string(&path)?;
237 let profile_value: Value = serde_json::from_str(&profile).unwrap();
238 DriftProfile::from_value(profile_value)
239 }
240}
241
242impl Default for DriftProfile {
243 fn default() -> Self {
244 DriftProfile::Spc(SpcDriftProfile::default())
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_drift_type_from_str_base() {
        assert_eq!(DriftType::from_str("SPC").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("PSI").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("CUSTOM").unwrap(), DriftType::Custom);
        // Parsing is case-insensitive.
        assert_eq!(DriftType::from_str("spc").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("Psi").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("cUsToM").unwrap(), DriftType::Custom);
        assert!(DriftType::from_str("INVALID").is_err());
        assert!(DriftType::from_str("").is_err());
    }

    #[test]
    fn test_drift_type_value_base() {
        assert_eq!(DriftType::Spc.to_string(), "Spc");
        assert_eq!(DriftType::Psi.to_string(), "Psi");
        assert_eq!(DriftType::Custom.to_string(), "Custom");
    }

    #[test]
    fn test_drift_type_display_round_trips_through_from_str() {
        for drift_type in [DriftType::Spc, DriftType::Psi, DriftType::Custom] {
            let rendered = format!("{drift_type}");
            // `Display` and the Python-facing getter must agree on the canonical name.
            assert_eq!(rendered, drift_type.to_string());
            assert_eq!(DriftType::from_str(&rendered).unwrap(), drift_type);
        }
    }

    #[test]
    fn test_drift_type_from_value() {
        assert_eq!(DriftType::from_value("spc"), Some(DriftType::Spc));
        assert_eq!(DriftType::from_value("Psi"), Some(DriftType::Psi));
        assert_eq!(DriftType::from_value("CUSTOM"), Some(DriftType::Custom));
        assert_eq!(DriftType::from_value("invalid"), None);
        assert_eq!(DriftType::from_value(""), None);
    }

    #[test]
    fn test_drift_profile_default_is_spc() {
        let profile = DriftProfile::default();
        assert_eq!(profile.drift_type(), DriftType::Spc);
        assert!(profile.get_spc_profile().is_ok());
        assert!(profile.get_psi_profile().is_err());
    }

    #[test]
    fn test_drift_profile_enum() {
        let profile = DriftProfile::Spc(SpcDriftProfile::default());

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("profile.json");

        profile.save_to_json(Some(path.clone())).unwrap();

        assert!(path.exists());

        let loaded_profile = DriftProfile::load_from_json(path).unwrap();

        assert_eq!(profile, loaded_profile);
    }
}