1use crate::custom::CustomDriftProfile;
2use crate::error::ProfileError;
3use crate::genai::profile::GenAIEvalProfile;
4use crate::psi::PsiDriftProfile;
5use crate::spc::SpcDriftProfile;
6use crate::util::ProfileBaseArgs;
7use crate::{AlertDispatchConfig, ProfileArgs};
8use crate::{FileName, PyHelperFuncs};
9use pyo3::prelude::*;
10use pyo3::IntoPyObjectExt;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::fmt::Display;
14use std::ops::Deref;
15use std::path::PathBuf;
16use std::str::FromStr;
17use strum_macros::EnumIter;
18#[pyclass(eq)]
19#[derive(Debug, EnumIter, PartialEq, Serialize, Deserialize, Clone, Default, Eq, Hash)]
20pub enum DriftType {
21 #[default]
22 Spc,
23 Psi,
24 Custom,
25 GenAI,
26}
27
28#[pymethods]
29impl DriftType {
30 #[staticmethod]
31 pub fn from_value(value: &str) -> Option<DriftType> {
32 match value.to_lowercase().as_str() {
33 "spc" => Some(DriftType::Spc),
34 "psi" => Some(DriftType::Psi),
35 "custom" => Some(DriftType::Custom),
36 "genai" => Some(DriftType::GenAI),
37 _ => None,
38 }
39 }
40
41 #[getter]
42 pub fn to_string(&self) -> &str {
43 match self {
44 DriftType::Spc => "Spc",
45 DriftType::Psi => "Psi",
46 DriftType::Custom => "Custom",
47 DriftType::GenAI => "GenAI",
48 }
49 }
50}
51
52impl Deref for DriftType {
53 type Target = str;
54
55 fn deref(&self) -> &Self::Target {
56 match self {
57 DriftType::Spc => "spc",
58 DriftType::Psi => "psi",
59 DriftType::Custom => "custom",
60 DriftType::GenAI => "genai",
61 }
62 }
63}
64
65impl FromStr for DriftType {
66 type Err = ProfileError;
67
68 fn from_str(value: &str) -> Result<Self, Self::Err> {
69 match value.to_lowercase().as_str() {
70 "spc" => Ok(DriftType::Spc),
71 "psi" => Ok(DriftType::Psi),
72 "custom" => Ok(DriftType::Custom),
73 "genai" => Ok(DriftType::GenAI),
74 _ => Err(ProfileError::InvalidDriftTypeError),
75 }
76 }
77}
78
79impl Display for DriftType {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 DriftType::Spc => write!(f, "Spc"),
83 DriftType::Psi => write!(f, "Psi"),
84 DriftType::Custom => write!(f, "Custom"),
85 DriftType::GenAI => write!(f, "LLM"),
86 }
87 }
88}
89
90pub struct DriftArgs {
91 pub name: String,
92 pub space: String,
93 pub version: String,
94 pub dispatch_config: AlertDispatchConfig,
95}
96
97#[pyclass]
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub enum DriftProfile {
100 Spc(SpcDriftProfile),
101 Psi(PsiDriftProfile),
102 Custom(CustomDriftProfile),
103 GenAI(GenAIEvalProfile),
104}
105
106#[pymethods]
107impl DriftProfile {
108 #[getter]
109 pub fn profile<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, ProfileError> {
110 match self {
111 DriftProfile::Spc(profile) => Ok(profile.clone().into_bound_py_any(py)?),
112 DriftProfile::Psi(profile) => Ok(profile.clone().into_bound_py_any(py)?),
113 DriftProfile::Custom(profile) => Ok(profile.clone().into_bound_py_any(py)?),
114 DriftProfile::GenAI(profile) => Ok(profile.clone().into_bound_py_any(py)?),
115 }
116 }
117}
118
119impl DriftProfile {
120 pub fn from_str(drift_type: &DriftType, profile: &str) -> Result<Self, ProfileError> {
132 match drift_type {
133 DriftType::Spc => {
134 let profile = serde_json::from_str(profile)?;
135 Ok(DriftProfile::Spc(profile))
136 }
137 DriftType::Psi => {
138 let profile = serde_json::from_str(profile)?;
139 Ok(DriftProfile::Psi(profile))
140 }
141 DriftType::Custom => {
142 let profile = serde_json::from_str(profile)?;
143 Ok(DriftProfile::Custom(profile))
144 }
145 DriftType::GenAI => {
146 let profile = serde_json::from_str(profile)?;
147 Ok(DriftProfile::GenAI(profile))
148 }
149 }
150 }
151
152 pub fn get_base_args(&self) -> ProfileArgs {
154 match self {
155 DriftProfile::Spc(profile) => profile.get_base_args(),
156 DriftProfile::Psi(profile) => profile.get_base_args(),
157 DriftProfile::Custom(profile) => profile.get_base_args(),
158 DriftProfile::GenAI(profile) => profile.get_base_args(),
159 }
160 }
161
162 pub fn to_value(&self) -> serde_json::Value {
163 match self {
164 DriftProfile::Spc(profile) => profile.to_value(),
165 DriftProfile::Psi(profile) => profile.to_value(),
166 DriftProfile::Custom(profile) => profile.to_value(),
167 DriftProfile::GenAI(profile) => profile.to_value(),
168 }
169 }
170
171 pub fn from_value(body: serde_json::Value) -> Result<Self, ProfileError> {
180 let drift_type = body["config"]["drift_type"].as_str().unwrap();
181
182 let drift_type = DriftType::from_str(drift_type)?;
183
184 match drift_type {
185 DriftType::Spc => {
186 let profile = serde_json::from_value(body)?;
187 Ok(DriftProfile::Spc(profile))
188 }
189 DriftType::Psi => {
190 let profile = serde_json::from_value(body)?;
191 Ok(DriftProfile::Psi(profile))
192 }
193 DriftType::Custom => {
194 let profile = serde_json::from_value(body)?;
195 Ok(DriftProfile::Custom(profile))
196 }
197 DriftType::GenAI => {
198 let profile = serde_json::from_value(body)?;
199 Ok(DriftProfile::GenAI(profile))
200 }
201 }
202 }
203
204 pub fn from_python(profile: &Bound<'_, PyAny>) -> Result<Self, ProfileError> {
205 let drift_type: DriftType = profile.getattr("drift_type")?.extract::<DriftType>()?;
206 match drift_type {
207 DriftType::Spc => {
208 let profile = profile.extract::<SpcDriftProfile>()?;
209 Ok(DriftProfile::Spc(profile))
210 }
211 DriftType::Psi => {
212 let profile = profile.extract::<PsiDriftProfile>()?;
213 Ok(DriftProfile::Psi(profile))
214 }
215 DriftType::Custom => {
216 let profile = profile.extract::<CustomDriftProfile>()?;
217 Ok(DriftProfile::Custom(profile))
218 }
219 DriftType::GenAI => {
220 let profile = profile.extract::<GenAIEvalProfile>()?;
221 Ok(DriftProfile::GenAI(profile))
222 }
223 }
224 }
225
226 pub fn get_spc_profile(&self) -> Result<&SpcDriftProfile, ProfileError> {
227 match self {
228 DriftProfile::Spc(profile) => Ok(profile),
229 _ => Err(ProfileError::InvalidDriftTypeError),
230 }
231 }
232
233 pub fn get_psi_profile(&self) -> Result<&PsiDriftProfile, ProfileError> {
234 match self {
235 DriftProfile::Psi(profile) => Ok(profile),
236 _ => Err(ProfileError::InvalidDriftTypeError),
237 }
238 }
239
240 pub fn get_genai_profile(&self) -> Result<&GenAIEvalProfile, ProfileError> {
241 match self {
242 DriftProfile::GenAI(profile) => Ok(profile),
243 _ => Err(ProfileError::InvalidDriftTypeError),
244 }
245 }
246
247 pub fn drift_type(&self) -> DriftType {
248 match self {
249 DriftProfile::Spc(_) => DriftType::Spc,
250 DriftProfile::Psi(_) => DriftType::Psi,
251 DriftProfile::Custom(_) => DriftType::Custom,
252 DriftProfile::GenAI(_) => DriftType::GenAI,
253 }
254 }
255
256 pub fn save_to_json(&self, path: Option<PathBuf>) -> Result<PathBuf, ProfileError> {
257 Ok(PyHelperFuncs::save_to_json(
258 self,
259 path,
260 FileName::DriftProfile.to_str(),
261 )?)
262 }
263
264 pub fn load_from_json(path: PathBuf) -> Result<Self, ProfileError> {
265 let file = std::fs::read_to_string(&path)?;
266 Ok(serde_json::from_str(&file)?)
267 }
268
269 pub fn from_profile_path(path: PathBuf) -> Result<Self, ProfileError> {
277 let profile = std::fs::read_to_string(&path)?;
278 let profile_value: Value = serde_json::from_str(&profile).unwrap();
279 DriftProfile::from_value(profile_value)
280 }
281
282 pub fn version(&self) -> Option<String> {
283 match self {
284 DriftProfile::Spc(profile) => Some(profile.config.version.clone()),
285 DriftProfile::Psi(profile) => Some(profile.config.version.clone()),
286 DriftProfile::Custom(profile) => Some(profile.config.version.clone()),
287 DriftProfile::GenAI(profile) => Some(profile.config.version.clone()),
288 }
289 }
290
291 pub fn identifier(&self) -> String {
292 match self {
293 DriftProfile::Spc(profile) => {
294 format!(
295 "{}/{}/v{}/spc",
296 profile.config.space, profile.config.name, profile.config.version
297 )
298 }
299 DriftProfile::Psi(profile) => {
300 format!(
301 "{}/{}/v{}/psi",
302 profile.config.space, profile.config.name, profile.config.version
303 )
304 }
305 DriftProfile::Custom(profile) => {
306 format!(
307 "{}/{}/v{}/custom",
308 profile.config.space, profile.config.name, profile.config.version
309 )
310 }
311 DriftProfile::GenAI(profile) => {
312 format!(
313 "{}/{}/v{}/genai",
314 profile.config.space, profile.config.name, profile.config.version
315 )
316 }
317 }
318 }
319
320 pub fn uid(&self) -> &String {
321 match self {
322 DriftProfile::Spc(profile) => &profile.config.uid,
323 DriftProfile::Psi(profile) => &profile.config.uid,
324 DriftProfile::Custom(profile) => &profile.config.uid,
325 DriftProfile::GenAI(profile) => &profile.config.uid,
326 }
327 }
328}
329
330impl Default for DriftProfile {
331 fn default() -> Self {
332 DriftProfile::Spc(SpcDriftProfile::default())
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_drift_type_from_str_base() {
        assert_eq!(DriftType::from_str("SPC").unwrap(), DriftType::Spc);
        assert_eq!(DriftType::from_str("PSI").unwrap(), DriftType::Psi);
        assert_eq!(DriftType::from_str("CUSTOM").unwrap(), DriftType::Custom);
        assert_eq!(DriftType::from_str("GENAI").unwrap(), DriftType::GenAI);
        assert!(DriftType::from_str("INVALID").is_err());
    }

    #[test]
    fn test_drift_type_value_base() {
        assert_eq!(DriftType::Spc.to_string(), "Spc");
        assert_eq!(DriftType::Psi.to_string(), "Psi");
        assert_eq!(DriftType::Custom.to_string(), "Custom");
        assert_eq!(DriftType::GenAI.to_string(), "GenAI");
    }

    /// The Python-facing `from_value` must accept exactly what `FromStr` accepts.
    #[test]
    fn test_drift_type_from_value_matches_from_str() {
        for input in ["spc", "Psi", "CUSTOM", "genai", "invalid", ""] {
            assert_eq!(DriftType::from_value(input), DriftType::from_str(input).ok());
        }
    }

    /// Every variant's lowercase `Deref` identifier parses back to that variant.
    #[test]
    fn test_drift_type_deref_round_trip() {
        for drift_type in [
            DriftType::Spc,
            DriftType::Psi,
            DriftType::Custom,
            DriftType::GenAI,
        ] {
            assert_eq!(DriftType::from_str(&drift_type).unwrap(), drift_type);
        }
    }

    #[test]
    fn test_drift_profile_enum() {
        let profile = DriftProfile::Spc(SpcDriftProfile::default());

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("profile.json");

        profile.save_to_json(Some(path.clone())).unwrap();

        assert!(path.exists());

        let loaded_profile = DriftProfile::load_from_json(path).unwrap();

        assert_eq!(profile, loaded_profile);
    }

    #[test]
    fn test_default_drift_profile_is_spc() {
        let profile = DriftProfile::default();
        assert_eq!(profile.drift_type(), DriftType::Spc);
        assert!(profile.get_spc_profile().is_ok());
        assert!(profile.get_psi_profile().is_err());
        assert!(profile.get_genai_profile().is_err());
        assert!(profile.identifier().ends_with("/spc"));
    }
}