scouter_types/
util.rs

1use crate::error::{ProfileError, TypeError, UtilError};
2use crate::FeatureMap;
3use crate::{CommonCrons, DriftType};
4use chrono::{DateTime, Utc};
5use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString};
9use pyo3::IntoPyObjectExt;
10use rayon::prelude::*;
11use serde::Serialize;
12use serde_json::{json, Value};
13use std::collections::{BTreeSet, HashMap};
14use std::fmt::{Display, Formatter};
15use std::path::PathBuf;
16use std::str::FromStr;
17
/// The literal string `"__missing__"`.
// NOTE(review): presumably a placeholder for absent values; its consumers are
// not visible in this part of the file — confirm against callers.
pub const MISSING: &str = "__missing__";

/// The literal version string `"0.1.0"`.
// NOTE(review): presumably the fallback when no version is supplied — confirm.
pub const DEFAULT_VERSION: &str = "0.1.0";
20
/// Kinds of JSON artifacts that have a fixed, well-known file name.
pub enum FileName {
    SpcDriftMap,
    SpcDriftProfile,
    PsiDriftMap,
    PsiDriftProfile,
    CustomDriftProfile,
    DriftProfile,
    DataProfile,
}

impl FileName {
    /// Returns the fixed `.json` file name associated with this variant.
    pub fn to_str(&self) -> &'static str {
        match self {
            // SPC artifacts
            Self::SpcDriftMap => "spc_drift_map.json",
            Self::SpcDriftProfile => "spc_drift_profile.json",
            // PSI artifacts
            Self::PsiDriftMap => "psi_drift_map.json",
            Self::PsiDriftProfile => "psi_drift_profile.json",
            // Custom and generic profiles
            Self::CustomDriftProfile => "custom_drift_profile.json",
            Self::DriftProfile => "drift_profile.json",
            Self::DataProfile => "data_profile.json",
        }
    }
}
44
45pub struct ProfileFuncs {}
46
47impl ProfileFuncs {
48    pub fn __str__<T: Serialize>(object: T) -> String {
49        match ColoredFormatter::with_styler(
50            PrettyFormatter::default(),
51            Styler {
52                key: Color::Rgb(245, 77, 85).bold(),
53                string_value: Color::Rgb(249, 179, 93).foreground(),
54                float_value: Color::Rgb(249, 179, 93).foreground(),
55                integer_value: Color::Rgb(249, 179, 93).foreground(),
56                bool_value: Color::Rgb(249, 179, 93).foreground(),
57                nil_value: Color::Rgb(249, 179, 93).foreground(),
58                ..Default::default()
59            },
60        )
61        .to_colored_json(&object, ColorMode::On)
62        {
63            Ok(json) => json,
64            Err(e) => format!("Failed to serialize to json: {}", e),
65        }
66        // serialize the struct to a string
67    }
68
69    pub fn __json__<T: Serialize>(object: T) -> String {
70        match serde_json::to_string_pretty(&object) {
71            Ok(json) => json,
72            Err(e) => format!("Failed to serialize to json: {}", e),
73        }
74    }
75
76    pub fn save_to_json<T>(
77        model: T,
78        path: Option<PathBuf>,
79        filename: &str,
80    ) -> Result<PathBuf, UtilError>
81    where
82        T: Serialize,
83    {
84        // serialize the struct to a string
85        let json = serde_json::to_string_pretty(&model)?;
86
87        // check if path is provided
88        let write_path = if path.is_some() {
89            let mut new_path = path.ok_or(UtilError::CreatePathError)?;
90
91            // ensure .json extension
92            new_path.set_extension("json");
93
94            if !new_path.exists() {
95                // ensure path exists, create if not
96                let parent_path = new_path.parent().ok_or(UtilError::GetParentPathError)?;
97
98                std::fs::create_dir_all(parent_path)
99                    .map_err(|_| UtilError::CreateDirectoryError)?;
100            }
101
102            new_path
103        } else {
104            PathBuf::from(filename)
105        };
106
107        std::fs::write(&write_path, json)?;
108
109        Ok(write_path)
110    }
111}
112
113pub fn json_to_pyobject(py: Python, value: &Value, dict: &Bound<'_, PyDict>) -> PyResult<()> {
114    match value {
115        Value::Object(map) => {
116            for (k, v) in map {
117                let py_value = match v {
118                    Value::Null => py.None(),
119                    Value::Bool(b) => b.into_py_any(py).unwrap(),
120                    Value::Number(n) => {
121                        if let Some(i) = n.as_i64() {
122                            i.into_py_any(py).unwrap()
123                        } else if let Some(f) = n.as_f64() {
124                            f.into_py_any(py).unwrap()
125                        } else {
126                            return Err(PyRuntimeError::new_err(
127                                "Invalid number type, expected i64 or f64",
128                            ));
129                        }
130                    }
131                    Value::String(s) => s.into_py_any(py).unwrap(),
132                    Value::Array(arr) => {
133                        let py_list = PyList::empty(py);
134                        for item in arr {
135                            let py_item = json_to_pyobject_value(py, item)?;
136                            py_list.append(py_item)?;
137                        }
138                        py_list.into_py_any(py).unwrap()
139                    }
140                    Value::Object(_) => {
141                        let nested_dict = PyDict::new(py);
142                        json_to_pyobject(py, v, &nested_dict)?;
143                        nested_dict.into_py_any(py).unwrap()
144                    }
145                };
146                dict.set_item(k, py_value)?;
147            }
148        }
149        _ => return Err(PyRuntimeError::new_err("Root must be object")),
150    }
151    Ok(())
152}
153
154pub fn json_to_pyobject_value(py: Python, value: &Value) -> PyResult<PyObject> {
155    Ok(match value {
156        Value::Null => py.None(),
157        Value::Bool(b) => b.into_py_any(py).unwrap(),
158        Value::Number(n) => {
159            if let Some(i) = n.as_i64() {
160                i.into_py_any(py).unwrap()
161            } else if let Some(f) = n.as_f64() {
162                f.into_py_any(py).unwrap()
163            } else {
164                return Err(PyRuntimeError::new_err(
165                    "Invalid number type, expected i64 or f64",
166                ));
167            }
168        }
169        Value::String(s) => s.into_py_any(py).unwrap(),
170        Value::Array(arr) => {
171            let py_list = PyList::empty(py);
172            for item in arr {
173                let py_item = json_to_pyobject_value(py, item)?;
174                py_list.append(py_item)?;
175            }
176            py_list.into_py_any(py).unwrap()
177        }
178        Value::Object(_) => {
179            let nested_dict = PyDict::new(py);
180            json_to_pyobject(py, value, &nested_dict)?;
181            nested_dict.into_py_any(py).unwrap()
182        }
183    })
184}
185
186pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
187    if obj.is_instance_of::<PyDict>() {
188        let dict = obj.downcast::<PyDict>()?;
189        let mut map = serde_json::Map::new();
190        for (key, value) in dict.iter() {
191            let key_str = key.extract::<String>()?;
192            let json_value = pyobject_to_json(&value)?;
193            map.insert(key_str, json_value);
194        }
195        Ok(Value::Object(map))
196    } else if obj.is_instance_of::<PyList>() {
197        let list = obj.downcast::<PyList>()?;
198        let mut vec = Vec::new();
199        for item in list.iter() {
200            vec.push(pyobject_to_json(&item)?);
201        }
202        Ok(Value::Array(vec))
203    } else if obj.is_instance_of::<PyString>() {
204        let s = obj.extract::<String>()?;
205        Ok(Value::String(s))
206    } else if obj.is_instance_of::<PyFloat>() {
207        let f = obj.extract::<f64>()?;
208        Ok(json!(f))
209    } else if obj.is_instance_of::<PyBool>() {
210        let b = obj.extract::<bool>()?;
211        Ok(json!(b))
212    } else if obj.is_instance_of::<PyInt>() {
213        let i = obj.extract::<i64>()?;
214        Ok(json!(i))
215    } else if obj.is_none() {
216        Ok(Value::Null)
217    } else {
218        Err(PyRuntimeError::new_err("Unsupported type"))
219    }
220}
221
222pub fn create_feature_map(
223    features: &[String],
224    array: &[Vec<String>],
225) -> Result<FeatureMap, ProfileError> {
226    // check if features and array are the same length
227    if features.len() != array.len() {
228        return Err(ProfileError::FeatureArrayLengthError);
229    };
230
231    let feature_map = array
232        .par_iter()
233        .enumerate()
234        .map(|(i, col)| {
235            let unique = col
236                .iter()
237                .collect::<BTreeSet<_>>()
238                .into_iter()
239                .collect::<Vec<_>>();
240            let mut map = HashMap::new();
241            for (j, item) in unique.iter().enumerate() {
242                map.insert(item.to_string(), j);
243
244                // check if j is last index
245                if j == unique.len() - 1 {
246                    // insert missing value
247                    map.insert("missing".to_string(), j + 1);
248                }
249            }
250
251            (features[i].to_string(), map)
252        })
253        .collect::<HashMap<_, _>>();
254
255    Ok(FeatureMap {
256        features: feature_map,
257    })
258}
259
260#[derive(PartialEq, Debug)]
261pub struct ProfileArgs {
262    pub name: String,
263    pub space: String,
264    pub version: String,
265    pub schedule: String,
266    pub scouter_version: String,
267    pub drift_type: DriftType,
268}
269
270// trait to implement on all profile types
271pub trait ProfileBaseArgs {
272    fn get_base_args(&self) -> ProfileArgs;
273    fn to_value(&self) -> serde_json::Value;
274}
275
276pub trait ValidateAlertConfig {
277    fn resolve_schedule(schedule: &str) -> String {
278        let default_schedule = CommonCrons::EveryDay.cron();
279
280        cron::Schedule::from_str(schedule) // Pass by reference here
281            .map(|_| schedule) // If valid, return the schedule
282            .unwrap_or_else(|_| {
283                tracing::error!("Invalid cron schedule, using default schedule");
284                &default_schedule
285            })
286            .to_string()
287    }
288}
289
290#[pyclass(eq)]
291#[derive(PartialEq, Debug)]
292pub enum DataType {
293    Pandas,
294    Polars,
295    Numpy,
296    Arrow,
297    Unknown,
298}
299
300impl Display for DataType {
301    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
302        match self {
303            DataType::Pandas => write!(f, "pandas"),
304            DataType::Polars => write!(f, "polars"),
305            DataType::Numpy => write!(f, "numpy"),
306            DataType::Arrow => write!(f, "arrow"),
307            DataType::Unknown => write!(f, "unknown"),
308        }
309    }
310}
311
312impl DataType {
313    pub fn from_module_name(module_name: &str) -> Result<Self, TypeError> {
314        match module_name {
315            "pandas.core.frame.DataFrame" => Ok(DataType::Pandas),
316            "polars.dataframe.frame.DataFrame" => Ok(DataType::Polars),
317            "numpy.ndarray" => Ok(DataType::Numpy),
318            "pyarrow.lib.Table" => Ok(DataType::Arrow),
319            _ => Err(TypeError::InvalidDataType),
320        }
321    }
322}
323
324pub fn get_utc_datetime() -> DateTime<Utc> {
325    Utc::now()
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty implementor used only to reach the trait's default method.
    pub struct TestStruct;
    impl ValidateAlertConfig for TestStruct {}

    #[test]
    fn test_resolve_schedule_base() {
        // A valid cron expression comes back unchanged.
        let valid = "0 0 5 * * *"; // Every day at 5:00 AM
        assert_eq!(TestStruct::resolve_schedule(valid), "0 0 5 * * *".to_string());

        // An unparseable expression resolves to the daily default.
        let expected_default = CommonCrons::EveryDay.cron();
        let resolved = TestStruct::resolve_schedule("invalid_cron");
        assert_eq!(resolved, expected_default);
    }
}