1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::llm::profile::LLMDriftProfile;
4use crate::psi::PsiDriftProfile;
5use crate::spc::SpcDriftProfile;
6use crate::util::ProfileBaseArgs;
7use crate::{AlertDispatchConfig, ProfileArgs};
8use crate::{FileName, PyHelperFuncs};
9use pyo3::prelude::*;
10use pyo3::IntoPyObjectExt;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::fmt::Display;
14use std::path::PathBuf;
15use std::str::FromStr;
16use strum_macros::EnumIter;
17#[pyclass(eq)]
18#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
19pub enum DriftType {
20 #[default]
21 Spc,
22 Psi,
23 Custom,
24 LLM,
25}
26
27#[pymethods]
28impl DriftType {
29 #[staticmethod]
30 pub fn from_value(value: &str) -> Option<DriftType> {
31 match value.to_lowercase().as_str() {
32 "spc" => Some(DriftType::Spc),
33 "psi" => Some(DriftType::Psi),
34 "custom" => Some(DriftType::Custom),
35 "llm" => Some(DriftType::LLM),
36 _ => None,
37 }
38 }
39
40 #[getter]
41 pub fn to_string(&self) -> &str {
42 match self {
43 DriftType::Spc => "Spc",
44 DriftType::Psi => "Psi",
45 DriftType::Custom => "Custom",
46 DriftType::LLM => "LLM",
47 }
48 }
49}
50
51impl FromStr for DriftType {
52 type Err = ProfileError;
53
54 fn from_str(value: &str) -> Result<Self, Self::Err> {
55 match value.to_lowercase().as_str() {
56 "spc" => Ok(DriftType::Spc),
57 "psi" => Ok(DriftType::Psi),
58 "custom" => Ok(DriftType::Custom),
59 "llm" => Ok(DriftType::LLM),
60 _ => Err(ProfileError::InvalidDriftTypeError),
61 }
62 }
63}
64
65impl Display for DriftType {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 DriftType::Spc => write!(f, "Spc"),
69 DriftType::Psi => write!(f, "Psi"),
70 DriftType::Custom => write!(f, "Custom"),
71 DriftType::LLM => write!(f, "LLM"),
72 }
73 }
74}
75
76pub struct DriftArgs {
77 pub name: String,
78 pub space: String,
79 pub version: String,
80 pub dispatch_config: AlertDispatchConfig,
81}
82
83#[pyclass]
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub enum DriftProfile {
86 Spc(SpcDriftProfile),
87 Psi(PsiDriftProfile),
88 Custom(CustomDriftProfile),
89 LLM(LLMDriftProfile),
90}
91
92#[pymethods]
93impl DriftProfile {
94 #[getter]
95 pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
96 match self {
97 DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
98 DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
99 DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
100 DriftProfile::LLM(profile) => Ok(profile.clone().into_bound_py_any(py)?),
101 }
102 }
103}
104
105impl DriftProfile {
106 pub fn from_str(drift_type: DriftType, profile: String) -> Result<Self, ProfileError> {
118 match drift_type {
119 DriftType::Spc => {
120 let profile = serde_json::from_str(&profile)?;
121 Ok(DriftProfile::Spc(profile))
122 }
123 DriftType::Psi => {
124 let profile = serde_json::from_str(&profile)?;
125 Ok(DriftProfile::Psi(profile))
126 }
127 DriftType::Custom => {
128 let profile = serde_json::from_str(&profile)?;
129 Ok(DriftProfile::Custom(profile))
130 }
131 DriftType::LLM => {
132 let profile = serde_json::from_str(&profile)?;
133 Ok(DriftProfile::LLM(profile))
134 }
135 }
136 }
137
138 pub fn get_base_args(&self) -> ProfileArgs {
140 match self {
141 DriftProfile::Spc(profile) => profile.get_base_args(),
142 DriftProfile::Psi(profile) => profile.get_base_args(),
143 DriftProfile::Custom(profile) => profile.get_base_args(),
144 DriftProfile::LLM(profile) => profile.get_base_args(),
145 }
146 }
147
148 pub fn to_value(&self) -> serde_json::Value {
149 match self {
150 DriftProfile::Spc(profile) => profile.to_value(),
151 DriftProfile::Psi(profile) => profile.to_value(),
152 DriftProfile::Custom(profile) => profile.to_value(),
153 DriftProfile::LLM(profile) => profile.to_value(),
154 }
155 }
156
157 pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
166 let drift_type = body["config"]["drift_type"].as_str().unwrap();
167
168 let drift_type = DriftType::from_str(drift_type)?;
169
170 match drift_type {
171 DriftType::Spc => {
172 let profile = serde_json::from_value(body)?;
173 Ok(DriftProfile::Spc(profile))
174 }
175 DriftType::Psi => {
176 let profile = serde_json::from_value(body)?;
177 Ok(DriftProfile::Psi(profile))
178 }
179 DriftType::Custom => {
180 let profile = serde_json::from_value(body)?;
181 Ok(DriftProfile::Custom(profile))
182 }
183 DriftType::LLM => {
184 let profile = serde_json::from_value(body)?;
185 Ok(DriftProfile::LLM(profile))
186 }
187 }
188 }
189
190 pub fn from_python(
191 drift_type: DriftType,
192 profile: &Bound<'_, PyAny>,
193 ) -> Result<Self, ProfileError> {
194 match drift_type {
195 DriftType::Spc => {
196 let profile = profile.extract::<SpcDriftProfile>()?;
197 Ok(DriftProfile::Spc(profile))
198 }
199 DriftType::Psi => {
200 let profile = profile.extract::<PsiDriftProfile>()?;
201 Ok(DriftProfile::Psi(profile))
202 }
203 DriftType::Custom => {
204 let profile = profile.extract::<CustomDriftProfile>()?;
205 Ok(DriftProfile::Custom(profile))
206 }
207 DriftType::LLM => {
208 let profile = profile.extract::<LLMDriftProfile>()?;
209 Ok(DriftProfile::LLM(profile))
210 }
211 }
212 }
213
214 pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
215 match self {
216 DriftProfile::Spc(profile) => Ok(profile),
217 _ => Err(ProfileError::InvalidDriftTypeError),
218 }
219 }
220
221 pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
222 match self {
223 DriftProfile::Psi(profile) => Ok(profile),
224 _ => Err(ProfileError::InvalidDriftTypeError),
225 }
226 }
227
228 pub fn get_llm_profile(&self) -> Result<&LLMDriftProfile, ProfileError> {
229 match self {
230 DriftProfile::LLM(profile) => Ok(profile),
231 _ => Err(ProfileError::InvalidDriftTypeError),
232 }
233 }
234
235 pub fn drift_type(&self) -> DriftType {
236 match self {
237 DriftProfile::Spc(_) => DriftType::Spc,
238 DriftProfile::Psi(_) => DriftType::Psi,
239 DriftProfile::Custom(_) => DriftType::Custom,
240 DriftProfile::LLM(_) => DriftType::LLM,
241 }
242 }
243
244 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
245 Ok(PyHelperFuncs::save_to_json(
246 self,
247 path,
248 FileName::DriftProfile.to_str(),
249 )?)
250 }
251
252 pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
253 let file = std::fs::read_to_string(&path)?;
254 Ok(serde_json::from_str(&file)?)
255 }
256
257 pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
265 let profile = std::fs::read_to_string(&path)?;
266 let profile_value: Value = serde_json::from_str(&profile).unwrap();
267 DriftProfile::from_value(profile_value)
268 }
269
270 pub fn version(&self) -> Option<String> {
271 match self {
272 DriftProfile::Spc(profile) => Some(profile.config.version.clone()),
273 DriftProfile::Psi(profile) => Some(profile.config.version.clone()),
274 DriftProfile::Custom(profile) => Some(profile.config.version.clone()),
275 DriftProfile::LLM(profile) => Some(profile.config.version.clone()),
276 }
277 }
278
279 pub fn identifier(&self) -> String {
280 match self {
281 DriftProfile::Spc(profile) => {
282 format!(
283 "{}/{}/v{}/spc",
284 profile.config.space, profile.config.name, profile.config.version
285 )
286 }
287 DriftProfile::Psi(profile) => {
288 format!(
289 "{}/{}/v{}/psi",
290 profile.config.space, profile.config.name, profile.config.version
291 )
292 }
293 DriftProfile::Custom(profile) => {
294 format!(
295 "{}/{}/v{}/custom",
296 profile.config.space, profile.config.name, profile.config.version
297 )
298 }
299 DriftProfile::LLM(profile) => {
300 format!(
301 "{}/{}/v{}/llm",
302 profile.config.space, profile.config.name, profile.config.version
303 )
304 }
305 }
306 }
307}
308
309impl Default for DriftProfile {
310 fn default() -> Self {
311 DriftProfile::Spc(SpcDriftProfile::default())
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Every variant parses case-insensitively; unknown names are rejected.
    #[test]
    fn test_drift_type_from_str_base() {
        assert_eq!(DriftType::from_str("SPC").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("PSI").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("CUSTOM").unwrap(), DriftType::Custom);
        assert_eq!(DriftType::from_str("LLM").unwrap(), DriftType::LLM);
        assert_eq!(DriftType::from_str("llm").unwrap(), DriftType::LLM);
        assert!(DriftType::from_str("INVALID").is_err());

        // `from_value` accepts the same names but reports failure as `None`.
        assert_eq!(DriftType::from_value("Psi"), Some(DriftType::Psi));
        assert_eq!(DriftType::from_value("unknown"), None);
    }

    /// The getter and `Display` both produce the variant name.
    #[test]
    fn test_drift_type_value_base() {
        assert_eq!(DriftType::Spc.to_string(), "Spc");
        assert_eq!(DriftType::Psi.to_string(), "Psi");
        assert_eq!(DriftType::Custom.to_string(), "Custom");
        assert_eq!(DriftType::LLM.to_string(), "LLM");
        assert_eq!(format!("{}", DriftType::LLM), "LLM");
    }

    /// A profile saved with `save_to_json` round-trips through `load_from_json`.
    #[test]
    fn test_drift_profile_enum() {
        let profile = DriftProfile::Spc(SpcDriftProfile::default());

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("profile.json");

        profile.save_to_json(Some(path.clone())).unwrap();
        assert!(path.exists());

        let loaded_profile = DriftProfile::load_from_json(path).unwrap();

        assert_eq!(loaded_profile.drift_type(), DriftType::Spc);
        assert_eq!(profile, loaded_profile);
    }
}