scouter_client/data_utils/
arrow.rs

1use crate::data_utils::{ConvertedData, DataConverter, DataTypes};
2use crate::error::DataError;
3use pyo3::prelude::*;
4
5pub struct ArrowDataConverter;
6
7impl DataConverter for ArrowDataConverter {
8    #[allow(clippy::if_same_then_else)]
9    fn categorize_features<'py>(
10        py: Python<'py>,
11        data: &Bound<'py, PyAny>,
12    ) -> Result<DataTypes, DataError> {
13        let mut string_features = Vec::new();
14        let mut integer_features = Vec::new();
15        let mut float_features = Vec::new();
16        let features = data.getattr("column_names")?.extract::<Vec<String>>()?;
17        let schema = data.getattr("schema")?;
18
19        for feature in features {
20            let dtype = schema.call_method1("field", (&feature,))?.getattr("type")?;
21            // assert dtype does not in [pa.int8(), pa.int16(), pa.int32(), pa.int64(), pa.float32(), pa.float64()]
22            let pa_types = py.import("pyarrow")?.getattr("types")?;
23
24            if pa_types
25                .call_method1("is_integer", (&dtype,))?
26                .extract::<bool>()?
27            {
28                integer_features.push(feature);
29            } else if pa_types
30                .call_method1("is_floating", (&dtype,))?
31                .extract::<bool>()?
32            {
33                float_features.push(feature);
34            } else if pa_types
35                .call_method1("is_decimal", (&dtype,))?
36                .extract::<bool>()?
37            {
38                float_features.push(feature);
39            } else {
40                string_features.push(feature);
41            }
42        }
43
44        Ok(DataTypes::new(
45            integer_features,
46            float_features,
47            string_features,
48        ))
49    }
50
51    fn process_numeric_features<'py>(
52        data: &Bound<'py, PyAny>,
53        data_types: &DataTypes,
54    ) -> Result<(Option<Bound<'py, PyAny>>, Option<String>), DataError> {
55        let py = data.py();
56        if data_types.numeric_features.is_empty() {
57            return Ok((None, None));
58        }
59
60        let is_mixed_type = data_types.has_mixed_types();
61
62        let array = data_types
63            .numeric_features
64            .iter()
65            .map(|feature| {
66                let array = data
67                    .call_method1("column", (&feature,))?
68                    .call_method0("to_numpy")?;
69
70                // Convert all to f64
71                if is_mixed_type {
72                    Ok(array.call_method1("astype", ("float64",))?)
73                } else {
74                    Ok(array)
75                }
76            })
77            .collect::<Result<Vec<Bound<'py, PyAny>>, DataError>>()?;
78
79        let numpy = py.import("numpy")?;
80
81        // call numpy.column_stack on array
82        let array = numpy.call_method1("column_stack", (array,))?;
83        let dtype = Some(array.getattr("dtype")?.str()?.to_string());
84
85        Ok((Some(array), dtype))
86    }
87
88    #[allow(clippy::needless_lifetimes)]
89    fn process_string_features<'py>(
90        data: &Bound<'py, PyAny>,
91        features: &[String],
92    ) -> Result<Option<Vec<Vec<String>>>, DataError> {
93        if features.is_empty() {
94            return Ok(None);
95        }
96
97        let array = features
98            .iter()
99            .map(|feature| {
100                let array = data
101                    .call_method1("column", (&feature,))?
102                    .call_method0("to_pylist")?
103                    .extract::<Vec<String>>()?;
104                Ok(array)
105            })
106            .collect::<Result<Vec<Vec<String>>, DataError>>()?;
107        Ok(Some(array))
108    }
109
110    fn prepare_data<'py>(
111        py: Python<'py>,
112        data: &Bound<'py, PyAny>,
113    ) -> Result<ConvertedData<'py>, DataError> {
114        let data_types = ArrowDataConverter::categorize_features(py, data)?;
115
116        let (numeric_array, dtype) =
117            ArrowDataConverter::process_numeric_features(data, &data_types)?;
118        let string_array =
119            ArrowDataConverter::process_string_features(data, &data_types.string_features)?;
120
121        Ok((
122            data_types.numeric_features,
123            numeric_array,
124            dtype,
125            data_types.string_features,
126            string_array,
127        ))
128    }
129}