scouter_client/data_utils/
polars.rs

1use crate::data_utils::types::DataTypes;
2use crate::data_utils::{ConvertedData, DataConverter};
3use crate::error::DataError;
4use pyo3::prelude::*;
5pub struct PolarsDataConverter;
6
7impl DataConverter for PolarsDataConverter {
8    fn categorize_features<'py>(
9        py: Python<'py>,
10        data: &Bound<'py, PyAny>,
11    ) -> Result<DataTypes, DataError> {
12        let cs = py.import("polars")?.getattr("selectors")?;
13
14        let columns = data.getattr("columns")?.extract::<Vec<String>>()?;
15
16        let integer_features = data
17            .call_method1("select", (&cs.call_method0("integer")?,))?
18            .getattr("columns")?
19            .extract::<Vec<String>>()?;
20
21        let float_features = data
22            .call_method1("select", (&cs.call_method0("float")?,))?
23            .getattr("columns")?
24            .extract::<Vec<String>>()?;
25
26        let string_features = columns
27            .iter()
28            .filter(|col| !float_features.contains(col) && !integer_features.contains(col))
29            .cloned()
30            .collect();
31
32        Ok(DataTypes::new(
33            integer_features,
34            float_features,
35            string_features,
36        ))
37    }
38
39    #[allow(clippy::needless_lifetimes)]
40    fn process_numeric_features<'py>(
41        data: &Bound<'py, PyAny>,
42        data_types: &DataTypes,
43    ) -> Result<(Option<Bound<'py, PyAny>>, Option<String>), DataError> {
44        if data_types.numeric_features.is_empty() {
45            return Ok((None, None));
46        }
47
48        // If mixed types, we cast to Float64 to ensure consistency
49        let array = if data_types.has_mixed_types() {
50            let py = data.py();
51            let float64 = py.import("polars")?.getattr("Float64")?;
52
53            data.get_item(&data_types.numeric_features)?
54                .call_method1("cast", (float64,))?
55                .call_method0("to_numpy")?
56        } else {
57            data.get_item(&data_types.numeric_features)?
58                .call_method0("to_numpy")?
59        };
60
61        let dtype = Some(array.getattr("dtype")?.str()?.to_string());
62
63        Ok((Some(array), dtype))
64    }
65
66    #[allow(clippy::needless_lifetimes)]
67    fn process_string_features<'py>(
68        data: &Bound<'py, PyAny>,
69        features: &[String],
70    ) -> Result<Option<Vec<Vec<String>>>, DataError> {
71        if features.is_empty() {
72            return Ok(None);
73        }
74
75        let py = data.py();
76        let polars = py.import("polars")?;
77        let pl_string = polars.getattr("String")?;
78
79        Ok(Some(
80            features
81                .iter()
82                .map(|feature| {
83                    let array = data
84                        .get_item(feature)?
85                        .call_method1("cast", (pl_string.clone(),))?
86                        .call_method0("to_list")?
87                        .extract::<Vec<String>>()?;
88                    Ok(array)
89                })
90                .collect::<Result<Vec<Vec<String>>, DataError>>()?,
91        ))
92    }
93
94    fn prepare_data<'py>(
95        py: Python<'py>,
96        data: &Bound<'py, PyAny>,
97    ) -> Result<ConvertedData<'py>, DataError> {
98        let data_types = PolarsDataConverter::categorize_features(py, data)?;
99
100        let (numeric_array, dtype) =
101            PolarsDataConverter::process_numeric_features(data, &data_types)?;
102        let string_array =
103            PolarsDataConverter::process_string_features(data, &data_types.string_features)?;
104
105        Ok((
106            data_types.numeric_features,
107            numeric_array,
108            dtype,
109            data_types.string_features,
110            string_array,
111        ))
112    }
113}