scouter_types/
util.rs

1use crate::error::{ProfileError, TypeError, UtilError};
2use crate::FeatureMap;
3use crate::{CommonCrons, DriftType};
4use chrono::{DateTime, Utc};
5use colored_json::{Color, ColorMode, ColoredFormatter, PrettyFormatter, Styler};
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString};
9use pyo3::IntoPyObjectExt;
10use rayon::prelude::*;
11use serde::{Deserialize, Serialize};
12use serde_json::{json, Value};
13use std::collections::{BTreeSet, HashMap};
14use std::fmt::{Display, Formatter};
15use std::path::PathBuf;
16use std::str::FromStr;
17
/// Sentinel string `"__missing__"`; presumably marks missing values — NOTE(review): confirm
/// against callers. `create_feature_map` in this file uses the literal key `"missing"`, not
/// this constant.
pub const MISSING: &str = "__missing__";
/// Version string `"0.0.0"`; presumably the fallback when no version is supplied — confirm
/// against callers.
pub const DEFAULT_VERSION: &str = "0.0.0";
20
21pub fn scouter_version() -> String {
22    env!("CARGO_PKG_VERSION").to_string()
23}
24
/// Fixed file names (all with a `.json` extension) for serialized profile and drift-map
/// artifacts.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FileName {
    SpcDriftMap,
    SpcDriftProfile,
    PsiDriftMap,
    PsiDriftProfile,
    CustomDriftProfile,
    DriftProfile,
    DataProfile,
    LLMDriftProfile,
}

impl FileName {
    /// Returns the file name, including the `.json` extension, for this variant.
    pub fn to_str(&self) -> &'static str {
        match self {
            FileName::SpcDriftMap => "spc_drift_map.json",
            FileName::SpcDriftProfile => "spc_drift_profile.json",
            FileName::PsiDriftMap => "psi_drift_map.json",
            FileName::PsiDriftProfile => "psi_drift_profile.json",
            FileName::CustomDriftProfile => "custom_drift_profile.json",
            FileName::DataProfile => "data_profile.json",
            FileName::DriftProfile => "drift_profile.json",
            FileName::LLMDriftProfile => "llm_drift_profile.json",
        }
    }
}
50
51pub struct PyHelperFuncs {}
52
53impl PyHelperFuncs {
54    pub fn __str__<T: Serialize>(object: T) -> String {
55        match ColoredFormatter::with_styler(
56            PrettyFormatter::default(),
57            Styler {
58                key: Color::Rgb(245, 77, 85).bold(),
59                string_value: Color::Rgb(249, 179, 93).foreground(),
60                float_value: Color::Rgb(249, 179, 93).foreground(),
61                integer_value: Color::Rgb(249, 179, 93).foreground(),
62                bool_value: Color::Rgb(249, 179, 93).foreground(),
63                nil_value: Color::Rgb(249, 179, 93).foreground(),
64                ..Default::default()
65            },
66        )
67        .to_colored_json(&object, ColorMode::On)
68        {
69            Ok(json) => json,
70            Err(e) => format!("Failed to serialize to json: {e}"),
71        }
72        // serialize the struct to a string
73    }
74
75    pub fn __json__<T: Serialize>(object: T) -> String {
76        match serde_json::to_string_pretty(&object) {
77            Ok(json) => json,
78            Err(e) => format!("Failed to serialize to json: {e}"),
79        }
80    }
81
82    pub fn save_to_json<T>(
83        model: T,
84        path: Option<PathBuf>,
85        filename: &str,
86    ) -> Result<PathBuf, UtilError>
87    where
88        T: Serialize,
89    {
90        // serialize the struct to a string
91        let json = serde_json::to_string_pretty(&model)?;
92
93        // check if path is provided
94        let write_path = if path.is_some() {
95            let mut new_path = path.ok_or(UtilError::CreatePathError)?;
96
97            // ensure .json extension
98            new_path.set_extension("json");
99
100            if !new_path.exists() {
101                // ensure path exists, create if not
102                let parent_path = new_path.parent().ok_or(UtilError::GetParentPathError)?;
103
104                std::fs::create_dir_all(parent_path)
105                    .map_err(|_| UtilError::CreateDirectoryError)?;
106            }
107
108            new_path
109        } else {
110            PathBuf::from(filename)
111        };
112
113        std::fs::write(&write_path, json)?;
114
115        Ok(write_path)
116    }
117}
118
119pub fn json_to_pyobject(py: Python, value: &Value, dict: &Bound<'_, PyDict>) -> PyResult<()> {
120    match value {
121        Value::Object(map) => {
122            for (k, v) in map {
123                let py_value = match v {
124                    Value::Null => py.None(),
125                    Value::Bool(b) => b.into_py_any(py).unwrap(),
126                    Value::Number(n) => {
127                        if let Some(i) = n.as_i64() {
128                            i.into_py_any(py).unwrap()
129                        } else if let Some(f) = n.as_f64() {
130                            f.into_py_any(py).unwrap()
131                        } else {
132                            return Err(PyRuntimeError::new_err(
133                                "Invalid number type, expected i64 or f64",
134                            ));
135                        }
136                    }
137                    Value::String(s) => s.into_py_any(py).unwrap(),
138                    Value::Array(arr) => {
139                        let py_list = PyList::empty(py);
140                        for item in arr {
141                            let py_item = json_to_pyobject_value(py, item)?;
142                            py_list.append(py_item)?;
143                        }
144                        py_list.into_py_any(py).unwrap()
145                    }
146                    Value::Object(_) => {
147                        let nested_dict = PyDict::new(py);
148                        json_to_pyobject(py, v, &nested_dict)?;
149                        nested_dict.into_py_any(py).unwrap()
150                    }
151                };
152                dict.set_item(k, py_value)?;
153            }
154        }
155        _ => return Err(PyRuntimeError::new_err("Root must be object")),
156    }
157    Ok(())
158}
159
160pub fn json_to_pyobject_value(py: Python, value: &Value) -> PyResult<PyObject> {
161    Ok(match value {
162        Value::Null => py.None(),
163        Value::Bool(b) => b.into_py_any(py).unwrap(),
164        Value::Number(n) => {
165            if let Some(i) = n.as_i64() {
166                i.into_py_any(py).unwrap()
167            } else if let Some(f) = n.as_f64() {
168                f.into_py_any(py).unwrap()
169            } else {
170                return Err(PyRuntimeError::new_err(
171                    "Invalid number type, expected i64 or f64",
172                ));
173            }
174        }
175        Value::String(s) => s.into_py_any(py).unwrap(),
176        Value::Array(arr) => {
177            let py_list = PyList::empty(py);
178            for item in arr {
179                let py_item = json_to_pyobject_value(py, item)?;
180                py_list.append(py_item)?;
181            }
182            py_list.into_py_any(py).unwrap()
183        }
184        Value::Object(_) => {
185            let nested_dict = PyDict::new(py);
186            json_to_pyobject(py, value, &nested_dict)?;
187            nested_dict.into_py_any(py).unwrap()
188        }
189    })
190}
191
192pub fn pyobject_to_json(obj: &Bound<'_, PyAny>) -> PyResult<Value> {
193    if obj.is_instance_of::<PyDict>() {
194        let dict = obj.downcast::<PyDict>()?;
195        let mut map = serde_json::Map::new();
196        for (key, value) in dict.iter() {
197            let key_str = key.extract::<String>()?;
198            let json_value = pyobject_to_json(&value)?;
199            map.insert(key_str, json_value);
200        }
201        Ok(Value::Object(map))
202    } else if obj.is_instance_of::<PyList>() {
203        let list = obj.downcast::<PyList>()?;
204        let mut vec = Vec::new();
205        for item in list.iter() {
206            vec.push(pyobject_to_json(&item)?);
207        }
208        Ok(Value::Array(vec))
209    } else if obj.is_instance_of::<PyString>() {
210        let s = obj.extract::<String>()?;
211        Ok(Value::String(s))
212    } else if obj.is_instance_of::<PyFloat>() {
213        let f = obj.extract::<f64>()?;
214        Ok(json!(f))
215    } else if obj.is_instance_of::<PyBool>() {
216        let b = obj.extract::<bool>()?;
217        Ok(json!(b))
218    } else if obj.is_instance_of::<PyInt>() {
219        let i = obj.extract::<i64>()?;
220        Ok(json!(i))
221    } else if obj.is_none() {
222        Ok(Value::Null)
223    } else {
224        Err(PyRuntimeError::new_err("Unsupported type"))
225    }
226}
227
228pub fn create_feature_map(
229    features: &[String],
230    array: &[Vec<String>],
231) -> Result<FeatureMap, ProfileError> {
232    // check if features and array are the same length
233    if features.len() != array.len() {
234        return Err(ProfileError::FeatureArrayLengthError);
235    };
236
237    let feature_map = array
238        .par_iter()
239        .enumerate()
240        .map(|(i, col)| {
241            let unique = col
242                .iter()
243                .collect::<BTreeSet<_>>()
244                .into_iter()
245                .collect::<Vec<_>>();
246            let mut map = HashMap::new();
247            for (j, item) in unique.iter().enumerate() {
248                map.insert(item.to_string(), j);
249
250                // check if j is last index
251                if j == unique.len() - 1 {
252                    // insert missing value
253                    map.insert("missing".to_string(), j + 1);
254                }
255            }
256
257            (features[i].to_string(), map)
258        })
259        .collect::<HashMap<_, _>>();
260
261    Ok(FeatureMap {
262        features: feature_map,
263    })
264}
265
266/// Checks if python object is an instance of a Pydantic BaseModel
267/// # Arguments
268/// * `py` - Python interpreter instance
269/// * `obj` - Python object to check
270/// # Returns
271/// * `Ok(bool)` - `true` if the object is a Pydantic model
272/// * `Err(TypeError)` - if there was an error importing Pydantic or checking
273pub fn is_pydantic_model(py: Python, obj: &Bound<'_, PyAny>) -> Result<bool, TypeError> {
274    let pydantic = match py.import("pydantic") {
275        Ok(module) => module,
276        Err(e) => return Err(TypeError::FailedToImportPydantic(e.to_string())),
277    };
278    let basemodel = pydantic.getattr("BaseModel")?;
279
280    // check if context is a pydantic model
281    let is_basemodel = obj
282        .is_instance(&basemodel)
283        .map_err(|e| TypeError::FailedToCheckPydanticModel(e.to_string()))?;
284
285    Ok(is_basemodel)
286}
287
/// Identifying and configuration arguments shared by profile types; returned by
/// [`ProfileBaseArgs::get_base_args`].
#[derive(PartialEq, Debug)]
pub struct ProfileArgs {
    /// Profile name.
    pub name: String,
    /// Space the profile belongs to. NOTE(review): presumably a namespace/grouping — confirm.
    pub space: String,
    /// Optional profile version; `None` when unset.
    pub version: Option<String>,
    /// Schedule string. Presumably a cron expression (compare
    /// `ValidateAlertConfig::resolve_schedule`) — confirm against implementors.
    pub schedule: String,
    /// Scouter version string; presumably populated from `scouter_version()` — confirm.
    pub scouter_version: String,
    /// Drift type associated with the profile.
    pub drift_type: DriftType,
}
297
/// Trait intended to be implemented by all profile types. It exposes their shared base
/// arguments and a JSON value representation.
pub trait ProfileBaseArgs {
    /// Returns the profile's shared [`ProfileArgs`].
    fn get_base_args(&self) -> ProfileArgs;
    /// Returns a `serde_json::Value` representation of the profile.
    fn to_value(&self) -> serde_json::Value;
}
303
304pub trait ValidateAlertConfig {
305    fn resolve_schedule(schedule: &str) -> String {
306        let default_schedule = CommonCrons::EveryDay.cron();
307
308        cron::Schedule::from_str(schedule) // Pass by reference here
309            .map(|_| schedule) // If valid, return the schedule
310            .unwrap_or_else(|_| {
311                tracing::error!("Invalid cron schedule, using default schedule");
312                &default_schedule
313            })
314            .to_string()
315    }
316}
317
/// Kind of input data container, identified from a Python object's fully-qualified class
/// name by `DataType::from_module_name`.
#[pyclass(eq)]
#[derive(PartialEq, Debug)]
pub enum DataType {
    /// `pandas.core.frame.DataFrame`
    Pandas,
    /// `polars.dataframe.frame.DataFrame`
    Polars,
    /// `numpy.ndarray`
    Numpy,
    /// `pyarrow.lib.Table`
    Arrow,
    /// Not produced by `from_module_name`, which returns an error for unrecognized names.
    Unknown,
    /// `scouter_drift.llm.LLMRecord`
    LLM,
}
328
329impl Display for DataType {
330    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
331        match self {
332            DataType::Pandas => write!(f, "pandas"),
333            DataType::Polars => write!(f, "polars"),
334            DataType::Numpy => write!(f, "numpy"),
335            DataType::Arrow => write!(f, "arrow"),
336            DataType::Unknown => write!(f, "unknown"),
337            DataType::LLM => write!(f, "llm"),
338        }
339    }
340}
341
342impl DataType {
343    pub fn from_module_name(module_name: &str) -> Result<Self, TypeError> {
344        match module_name {
345            "pandas.core.frame.DataFrame" => Ok(DataType::Pandas),
346            "polars.dataframe.frame.DataFrame" => Ok(DataType::Polars),
347            "numpy.ndarray" => Ok(DataType::Numpy),
348            "pyarrow.lib.Table" => Ok(DataType::Arrow),
349            "scouter_drift.llm.LLMRecord" => Ok(DataType::LLM),
350            _ => Err(TypeError::InvalidDataType),
351        }
352    }
353}
354
/// Returns the current date and time in UTC (thin wrapper over `chrono::Utc::now`).
pub fn get_utc_datetime() -> DateTime<Utc> {
    Utc::now()
}
358
/// Threshold direction for alerts. It is exposed to Python and (de)serialized with serde.
/// NOTE(review): variant meanings below are inferred from names — confirm against the
/// alerting code.
#[pyclass(eq)]
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub enum AlertThreshold {
    /// Presumably alert when a value falls below the threshold.
    Below,
    /// Presumably alert when a value rises above the threshold.
    Above,
    /// Presumably alert when a value leaves the threshold band in either direction.
    Outside,
}
366
367impl Display for AlertThreshold {
368    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
369        write!(f, "{self:?}")
370    }
371}
372
373#[pymethods]
374impl AlertThreshold {
375    #[staticmethod]
376    pub fn from_value(value: &str) -> Option<Self> {
377        match value.to_lowercase().as_str() {
378            "below" => Some(AlertThreshold::Below),
379            "above" => Some(AlertThreshold::Above),
380            "outside" => Some(AlertThreshold::Outside),
381            _ => None,
382        }
383    }
384
385    pub fn __str__(&self) -> String {
386        match self {
387            AlertThreshold::Below => "Below".to_string(),
388            AlertThreshold::Above => "Above".to_string(),
389            AlertThreshold::Outside => "Outside".to_string(),
390        }
391    }
392}
393
/// Processing status. `All` is the default and has no string form in `Status::as_str`.
/// NOTE(review): presumably `All` means "no status filter" — confirm against callers.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub enum Status {
    #[default]
    All,
    Pending,
    Processing,
    Processed,
    Failed,
}
403
404impl Status {
405    pub fn as_str(&self) -> Option<&'static str> {
406        match self {
407            Status::All => None,
408            Status::Pending => Some("pending"),
409            Status::Processing => Some("processing"),
410            Status::Processed => Some("processed"),
411            Status::Failed => Some("failed"),
412        }
413    }
414}
415
416impl FromStr for Status {
417    type Err = TypeError;
418
419    fn from_str(s: &str) -> Result<Self, Self::Err> {
420        match s.to_lowercase().as_str() {
421            "all" => Ok(Status::All),
422            "pending" => Ok(Status::Pending),
423            "processing" => Ok(Status::Processing),
424            "processed" => Ok(Status::Processed),
425            "failed" => Ok(Status::Failed),
426            _ => Err(TypeError::InvalidStatusError(s.to_string())),
427        }
428    }
429}
430
431impl Display for Status {
432    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
433        match self {
434            Status::All => write!(f, "all"),
435            Status::Pending => write!(f, "pending"),
436            Status::Processing => write!(f, "processing"),
437            Status::Processed => write!(f, "processed"),
438            Status::Failed => write!(f, "failed"),
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal implementor so the trait's default `resolve_schedule` can be exercised.
    pub struct TestStruct;
    impl ValidateAlertConfig for TestStruct {}

    #[test]
    fn test_resolve_schedule_base() {
        // Six-field cron (sec min hour day month weekday): every day at 5:00 AM.
        // A valid expression must be returned unchanged.
        let every_day_at_five = "0 0 5 * * *";
        assert_eq!(
            TestStruct::resolve_schedule(every_day_at_five),
            "0 0 5 * * *".to_string()
        );
    }

    #[test]
    fn test_resolve_schedule_invalid_falls_back_to_default() {
        // An unparseable expression must be replaced by the every-day default.
        let resolved = TestStruct::resolve_schedule("invalid_cron");
        let expected = CommonCrons::EveryDay.cron();
        assert_eq!(resolved, expected);
    }
}