1use crate::error::{ProfileError, TypeError, UtilError};
2use crate::FeatureMap;
3use crate::{CommonCrons, DriftType};
4use chrono::{DateTime, Utc};
5use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString};
9use pyo3::IntoPyObjectExt;
10use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::{BTreeSet, HashMap};
14use std::fmt::{Display, Formatter};
15use std::path::PathBuf;
16use std::str::FromStr;
17
/// Sentinel string `"__missing__"`.
/// NOTE(review): consumers are outside this view — presumably marks missing/absent values; confirm at call sites.
pub const MISSING: &str = "__missing__";
/// Fallback version string `"0.0.0"`.
/// NOTE(review): presumably used when no explicit profile version is supplied; confirm at call sites.
pub const DEFAULT_VERSION: &str = "0.0.0";
20
21pub fn scouter_version() -> String {
22 env!("CARGO_PKG_VERSION").to_string()
23}
24
/// Well-known JSON file names used when persisting maps and profiles.
pub enum FileName {
    SpcDriftMap,
    SpcDriftProfile,
    PsiDriftMap,
    PsiDriftProfile,
    CustomDriftProfile,
    DriftProfile,
    DataProfile,
    LLMDriftProfile,
}

impl FileName {
    /// Returns the file name (including its `.json` extension) for this variant.
    pub fn to_str(&self) -> &'static str {
        match self {
            // Statistical-process-control artifacts.
            Self::SpcDriftMap => "spc_drift_map.json",
            Self::SpcDriftProfile => "spc_drift_profile.json",
            // Population-stability-index artifacts.
            Self::PsiDriftMap => "psi_drift_map.json",
            Self::PsiDriftProfile => "psi_drift_profile.json",
            // Remaining profile kinds.
            Self::CustomDriftProfile => "custom_drift_profile.json",
            Self::LLMDriftProfile => "llm_drift_profile.json",
            Self::DriftProfile => "drift_profile.json",
            Self::DataProfile => "data_profile.json",
        }
    }
}
50
51pub struct PyHelperFuncs {}
52
53impl PyHelperFuncs {
54 pub fn __str__<T: Serialize>(object: T) -> String {
55 match ColoredFormatter::with_styler(
56 PrettyFormatter::default(),
57 Styler {
58 key: Color::Rgb(245, 77, 85).bold(),
59 string_value: Color::Rgb(249, 179, 93).foreground(),
60 float_value: Color::Rgb(249, 179, 93).foreground(),
61 integer_value: Color::Rgb(249, 179, 93).foreground(),
62 bool_value: Color::Rgb(249, 179, 93).foreground(),
63 nil_value: Color::Rgb(249, 179, 93).foreground(),
64 ..Default::default()
65 },
66 )
67 .to_colored_json(&object, ColorMode::On)
68 {
69 Ok(json) => json,
70 Err(e) => format!("Failed to serialize to json: {e}"),
71 }
72 }
74
75 pub fn __json__<T: Serialize>(object: T) -> String {
76 match serde_json::to_string_pretty(&object) {
77 Ok(json) => json,
78 Err(e) => format!("Failed to serialize to json: {e}"),
79 }
80 }
81
82 pub fn save_to_json<T>(
83 model: T,
84 path: Option<PathBuf>,
85 filename: &str,
86 ) -> Result<PathBuf, UtilError>
87 where
88 T: Serialize,
89 {
90 let json = serde_json::to_string_pretty(&model)?;
92
93 let write_path = if path.is_some() {
95 let mut new_path = path.ok_or(UtilError::CreatePathError)?;
96
97 new_path.set_extension("json");
99
100 if !new_path.exists() {
101 let parent_path = new_path.parent().ok_or(UtilError::GetParentPathError)?;
103
104 std::fs::create_dir_all(parent_path)
105 .map_err(|_| UtilError::CreateDirectoryError)?;
106 }
107
108 new_path
109 } else {
110 PathBuf::from(filename)
111 };
112
113 std::fs::write(&write_path, json)?;
114
115 Ok(write_path)
116 }
117}
118
119pub fn json_to_pyobject(py: Python, value: &Value, dict: &Bound<'_, PyDict>) -> PyResult<()> {
120 match value {
121 Value::Object(map) => {
122 for (k, v) in map {
123 let py_value = match v {
124 Value::Null => py.None(),
125 Value::Bool(b) => b.into_py_any(py).unwrap(),
126 Value::Number(n) => {
127 if let Some(i) = n.as_i64() {
128 i.into_py_any(py).unwrap()
129 } else if let Some(f) = n.as_f64() {
130 f.into_py_any(py).unwrap()
131 } else {
132 return Err(PyRuntimeError::new_err(
133 "Invalid number type, expected i64 or f64",
134 ));
135 }
136 }
137 Value::String(s) => s.into_py_any(py).unwrap(),
138 Value::Array(arr) => {
139 let py_list = PyList::empty(py);
140 for item in arr {
141 let py_item = json_to_pyobject_value(py, item)?;
142 py_list.append(py_item)?;
143 }
144 py_list.into_py_any(py).unwrap()
145 }
146 Value::Object(_) => {
147 let nested_dict = PyDict::new(py);
148 json_to_pyobject(py, v, &nested_dict)?;
149 nested_dict.into_py_any(py).unwrap()
150 }
151 };
152 dict.set_item(k, py_value)?;
153 }
154 }
155 _ => return Err(PyRuntimeError::new_err("Root must be object")),
156 }
157 Ok(())
158}
159
160pub fn json_to_pyobject_value(py: Python, value: &Value) -> PyResult<PyObject> {
161 Ok(match value {
162 Value::Null => py.None(),
163 Value::Bool(b) => b.into_py_any(py).unwrap(),
164 Value::Number(n) => {
165 if let Some(i) = n.as_i64() {
166 i.into_py_any(py).unwrap()
167 } else if let Some(f) = n.as_f64() {
168 f.into_py_any(py).unwrap()
169 } else {
170 return Err(PyRuntimeError::new_err(
171 "Invalid number type, expected i64 or f64",
172 ));
173 }
174 }
175 Value::String(s) => s.into_py_any(py).unwrap(),
176 Value::Array(arr) => {
177 let py_list = PyList::empty(py);
178 for item in arr {
179 let py_item = json_to_pyobject_value(py, item)?;
180 py_list.append(py_item)?;
181 }
182 py_list.into_py_any(py).unwrap()
183 }
184 Value::Object(_) => {
185 let nested_dict = PyDict::new(py);
186 json_to_pyobject(py, value, &nested_dict)?;
187 nested_dict.into_py_any(py).unwrap()
188 }
189 })
190}
191
192pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
193 if obj.is_instance_of::<PyDict>() {
194 let dict = obj.downcast::<PyDict>()?;
195 let mut map = serde_json::Map::new();
196 for (key, value) in dict.iter() {
197 let key_str = key.extract::<String>()?;
198 let json_value = pyobject_to_json(&value)?;
199 map.insert(key_str, json_value);
200 }
201 Ok(Value::Object(map))
202 } else if obj.is_instance_of::<PyList>() {
203 let list = obj.downcast::<PyList>()?;
204 let mut vec = Vec::new();
205 for item in list.iter() {
206 vec.push(pyobject_to_json(&item)?);
207 }
208 Ok(Value::Array(vec))
209 } else if obj.is_instance_of::<PyString>() {
210 let s = obj.extract::<String>()?;
211 Ok(Value::String(s))
212 } else if obj.is_instance_of::<PyFloat>() {
213 let f = obj.extract::<f64>()?;
214 Ok(json!(f))
215 } else if obj.is_instance_of::<PyBool>() {
216 let b = obj.extract::<bool>()?;
217 Ok(json!(b))
218 } else if obj.is_instance_of::<PyInt>() {
219 let i = obj.extract::<i64>()?;
220 Ok(json!(i))
221 } else if obj.is_none() {
222 Ok(Value::Null)
223 } else {
224 Err(PyRuntimeError::new_err("Unsupported type"))
225 }
226}
227
228pub fn create_feature_map(
229 features: &[String],
230 array: &[Vec<String>],
231) -> Result<FeatureMap, ProfileError> {
232 if features.len() != array.len() {
234 return Err(ProfileError::FeatureArrayLengthError);
235 };
236
237 let feature_map = array
238 .par_iter()
239 .enumerate()
240 .map(|(i, col)| {
241 let unique = col
242 .iter()
243 .collect::<BTreeSet<_>>()
244 .into_iter()
245 .collect::<Vec<_>>();
246 let mut map = HashMap::new();
247 for (j, item) in unique.iter().enumerate() {
248 map.insert(item.to_string(), j);
249
250 if j == unique.len() - 1 {
252 map.insert("missing".to_string(), j + 1);
254 }
255 }
256
257 (features[i].to_string(), map)
258 })
259 .collect::<HashMap<_, _>>();
260
261 Ok(FeatureMap {
262 features: feature_map,
263 })
264}
265
266pub fn is_pydantic_model(py: Python, obj: &Bound<'_, PyAny>) -> Result<bool, TypeError> {
274 let pydantic = match py.import("pydantic") {
275 Ok(module) => module,
276 Err(e) => return Err(TypeError::FailedToImportPydantic(e.to_string())),
277 };
278 let basemodel = pydantic.getattr("BaseModel")?;
279
280 let is_basemodel = obj
282 .is_instance(&basemodel)
283 .map_err(|e| TypeError::FailedToCheckPydanticModel(e.to_string()))?;
284
285 Ok(is_basemodel)
286}
287
/// Common identifying and scheduling arguments of a profile, as returned by
/// [`ProfileBaseArgs::get_base_args`].
#[derive(PartialEq, Debug)]
pub struct ProfileArgs {
    /// Profile name.
    pub name: String,
    /// Space the profile belongs to.
    pub space: String,
    /// Optional profile version; `None` when not set.
    pub version: Option<String>,
    /// Schedule string — presumably a cron expression (cf. `ValidateAlertConfig::resolve_schedule`); confirm with callers.
    pub schedule: String,
    /// Scouter version associated with the profile — presumably from `scouter_version()`; verify at construction sites.
    pub scouter_version: String,
    /// Kind of drift the profile tracks.
    pub drift_type: DriftType,
}
297
/// Accessors shared by profile types.
pub trait ProfileBaseArgs {
    /// Returns the profile's base identifying arguments.
    fn get_base_args(&self) -> ProfileArgs;
    /// Returns the implementor as a `serde_json::Value` (NOTE(review): presumably the full serialised profile; confirm in implementors).
    fn to_value(&self) -> serde_json::Value;
}
303
304pub trait ValidateAlertConfig {
305 fn resolve_schedule(schedule: &str) -> String {
306 let default_schedule = CommonCrons::EveryDay.cron();
307
308 cron::Schedule::from_str(schedule) .map(|_| schedule) .unwrap_or_else(|_| {
311 tracing::error!("Invalid cron schedule, using default schedule");
312 &default_schedule
313 })
314 .to_string()
315 }
316}
317
318#[pyclass(eq)]
319#[derive(PartialEq, Debug)]
320pub enum DataType {
321 Pandas,
322 Polars,
323 Numpy,
324 Arrow,
325 Unknown,
326 LLM,
327}
328
329impl Display for DataType {
330 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
331 match self {
332 DataType::Pandas => write!(f, "pandas"),
333 DataType::Polars => write!(f, "polars"),
334 DataType::Numpy => write!(f, "numpy"),
335 DataType::Arrow => write!(f, "arrow"),
336 DataType::Unknown => write!(f, "unknown"),
337 DataType::LLM => write!(f, "llm"),
338 }
339 }
340}
341
342impl DataType {
343 pub fn from_module_name(module_name: &str) -> Result<Self, TypeError> {
344 match module_name {
345 "pandas.core.frame.DataFrame" => Ok(DataType::Pandas),
346 "polars.dataframe.frame.DataFrame" => Ok(DataType::Polars),
347 "numpy.ndarray" => Ok(DataType::Numpy),
348 "pyarrow.lib.Table" => Ok(DataType::Arrow),
349 "scouter_drift.llm.LLMRecord" => Ok(DataType::LLM),
350 _ => Err(TypeError::InvalidDataType),
351 }
352 }
353}
354
355pub fn get_utc_datetime() -> DateTime<Utc> {
356 Utc::now()
357}
358
359#[pyclass(eq)]
360#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
361pub enum AlertThreshold {
362 Below,
363 Above,
364 Outside,
365}
366
367impl Display for AlertThreshold {
368 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369 write!(f, "{self:?}")
370 }
371}
372
373#[pymethods]
374impl AlertThreshold {
375 #[staticmethod]
376 pub fn from_value(value: &str) -> Option<Self> {
377 match value.to_lowercase().as_str() {
378 "below" => Some(AlertThreshold::Below),
379 "above" => Some(AlertThreshold::Above),
380 "outside" => Some(AlertThreshold::Outside),
381 _ => None,
382 }
383 }
384
385 pub fn __str__(&self) -> String {
386 match self {
387 AlertThreshold::Below => "Below".to_string(),
388 AlertThreshold::Above => "Above".to_string(),
389 AlertThreshold::Outside => "Outside".to_string(),
390 }
391 }
392}
393
394#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
395pub enum Status {
396 #[default]
397 All,
398 Pending,
399 Processing,
400 Processed,
401 Failed,
402}
403
404impl Status {
405 pub fn as_str(&self) -> Option<&'static str> {
406 match self {
407 Status::All => None,
408 Status::Pending => Some("pending"),
409 Status::Processing => Some("processing"),
410 Status::Processed => Some("processed"),
411 Status::Failed => Some("failed"),
412 }
413 }
414}
415
416impl FromStr for Status {
417 type Err = TypeError;
418
419 fn from_str(s: &str) -> Result<Self, Self::Err> {
420 match s.to_lowercase().as_str() {
421 "all" => Ok(Status::All),
422 "pending" => Ok(Status::Pending),
423 "processing" => Ok(Status::Processing),
424 "processed" => Ok(Status::Processed),
425 "failed" => Ok(Status::Failed),
426 _ => Err(TypeError::InvalidStatusError(s.to_string())),
427 }
428 }
429}
430
431impl Display for Status {
432 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
433 match self {
434 Status::All => write!(f, "all"),
435 Status::Pending => write!(f, "pending"),
436 Status::Processing => write!(f, "processing"),
437 Status::Processed => write!(f, "processed"),
438 Status::Failed => write!(f, "failed"),
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal implementor so the trait's default `resolve_schedule` can be exercised.
    pub struct TestStruct;

    impl ValidateAlertConfig for TestStruct {}

    #[test]
    fn test_resolve_schedule_base() {
        // A well-formed six-field cron expression is returned unchanged.
        let valid = "0 0 5 * * *";
        assert_eq!(TestStruct::resolve_schedule(valid), valid.to_string());

        // An unparsable expression falls back to the daily default.
        let fallback = CommonCrons::EveryDay.cron();
        assert_eq!(TestStruct::resolve_schedule("invalid_cron"), fallback);
    }
}