//! scouter_dataframe/parquet/psi.rs — Parquet/DataFrame handling for PSI drift records.
use std::collections::BTreeMap;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow_array::array::{
    Int32Array, ListArray, StringArray, StructArray, TimestampNanosecondArray,
};
use arrow_array::types::{Float32Type, Int32Type};
use arrow_array::Array;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion::dataframe::DataFrame;
use datafusion::prelude::SessionContext;
use scouter_settings::ObjectStorageSettings;
use scouter_types::{
    psi::FeatureBinProportionResult, PsiRecord, ServerRecords, StorageType, ToDriftRecords,
};

use super::types::BinnedTableName;
use crate::error::DataFrameError;
use crate::parquet::traits::ParquetFrame;
use crate::parquet::utils::ParquetHelper;
use crate::sql::helper::get_binned_psi_drift_records_query;
use crate::storage::ObjectStore;
/// Parquet-backed frame for PSI (Population Stability Index) drift records.
pub struct PsiDataFrame {
    /// Arrow schema for PSI record batches; columns are
    /// created_at, entity_id, feature, bin_id, bin_count (see `PsiDataFrame::new`).
    schema: Arc<Schema>,
    /// Object store used to create sessions and resolve storage settings.
    pub object_store: ObjectStore,
}

31#[async_trait]
32impl ParquetFrame for PsiDataFrame {
33    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
34        PsiDataFrame::new(storage_settings)
35    }
36
37    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
38        let records = records.to_psi_drift_records()?;
39        let batch = self.build_batch(records)?;
40
41        let ctx = self.object_store.get_session()?;
42
43        let df = ctx.read_batches(vec![batch])?;
44        Ok(df)
45    }
46
47    fn storage_root(&self) -> String {
48        self.object_store.storage_settings.canonicalized_path()
49    }
50
51    fn storage_type(&self) -> StorageType {
52        self.object_store.storage_settings.storage_type.clone()
53    }
54
55    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
56        Ok(self.object_store.get_session()?)
57    }
58
59    fn get_binned_sql(
60        &self,
61        bin: &f64,
62        start_time: &DateTime<Utc>,
63        end_time: &DateTime<Utc>,
64        entity_id: &i32,
65    ) -> String {
66        get_binned_psi_drift_records_query(bin, start_time, end_time, entity_id)
67    }
68
69    fn table_name(&self) -> String {
70        BinnedTableName::Psi.to_string()
71    }
72}
73
74impl PsiDataFrame {
75    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
76        let schema = Arc::new(Schema::new(vec![
77            Field::new(
78                "created_at",
79                DataType::Timestamp(TimeUnit::Nanosecond, None),
80                false,
81            ),
82            Field::new("entity_id", DataType::Int32, false),
83            Field::new("feature", DataType::Utf8, false),
84            Field::new("bin_id", DataType::Int32, false),
85            Field::new("bin_count", DataType::Int32, false),
86        ]));
87
88        let object_store = ObjectStore::new(storage_settings)?;
89
90        Ok(PsiDataFrame {
91            schema,
92            object_store,
93        })
94    }
95
96    /// Create and arrow RecordBatch from the given records
97    fn build_batch(&self, records: Vec<PsiRecord>) -> Result<RecordBatch, DataFrameError> {
98        let created_at_array = TimestampNanosecondArray::from_iter_values(
99            records
100                .iter()
101                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
102        );
103
104        let entity_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.entity_id));
105        let feature_array =
106            StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));
107
108        let bin_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.bin_id));
109        let bin_count_array = Int32Array::from_iter_values(records.iter().map(|r| r.bin_count));
110
111        let batch = RecordBatch::try_new(
112            self.schema.clone(),
113            vec![
114                Arc::new(created_at_array),
115                Arc::new(entity_id_array),
116                Arc::new(feature_array),
117                Arc::new(bin_id_array),
118                Arc::new(bin_count_array),
119            ],
120        )?;
121
122        Ok(batch)
123    }
124}
125
126/// Extraction logic to get bin proportions from a return record batch
127fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
128    batch
129        .column(2)
130        .as_any()
131        .downcast_ref::<ListArray>()
132        .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
133}
134
135/// Extraction logic to get bin ids and proportions from a return record batch
136fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
137    let bin_ids = structs
138        .column_by_name("bin_id")
139        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
140        .as_any()
141        .downcast_ref::<ListArray>()
142        .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;
143
144    let proportions = structs
145        .column_by_name("proportion")
146        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
147        .as_any()
148        .downcast_ref::<ListArray>()
149        .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;
150
151    Ok((bin_ids, proportions))
152}
153
/// Convert the bin id array to a `Vec<i32>`.
///
/// Null entries are silently dropped by `flatten()`. Callers zip these ids with
/// the proportions array; NOTE(review): this assumes both arrays have nulls in
/// the same positions (or none) — confirm, otherwise the zip could misalign.
fn get_bin_ids(array: &dyn Array) -> Result<Vec<i32>, DataFrameError> {
    Ok(array.as_primitive::<Int32Type>().iter().flatten().collect())
}

159/// Convert the proportion array to a Vec<f64>
160/// TODO: Should we store f64 or f32?
161fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
162    Ok(array
163        .as_primitive::<Float32Type>()
164        .iter()
165        .filter_map(|p| p.map(|v| v as f64))
166        .collect())
167}
168
169/// Create a BTreeMap from the bin ids and proportions
170fn create_bin_map(
171    bin_ids: &ListArray,
172    proportions: &ListArray,
173    index: usize,
174) -> Result<BTreeMap<i32, f64>, DataFrameError> {
175    let bin_ids = get_bin_ids(&bin_ids.value(index))?;
176    let proportions = get_proportions(&proportions.value(index))?;
177
178    Ok(bin_ids.into_iter().zip(proportions).collect())
179}
180
181/// Extract bin proportions from a return record batch
182fn extract_bin_proportions(batch: &RecordBatch) -> Result<Vec<BTreeMap<i32, f64>>, DataFrameError> {
183    let bin_structs = get_bin_proportions_struct(batch)?.value(0);
184    let bin_structs = bin_structs
185        .as_any()
186        .downcast_ref::<StructArray>()
187        .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;
188
189    let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;
190
191    let mut bin_proportions = Vec::with_capacity(bin_structs.len());
192    for i in 0..bin_structs.len() {
193        let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
194        bin_proportions.push(bin_map);
195    }
196
197    Ok(bin_proportions)
198}
199
200/// Extract overall proportions from a return record batch
201fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
202    let overall_proportions_struct = batch
203        .column(3)
204        .as_any()
205        .downcast_ref::<StructArray>()
206        .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;
207
208    Ok(overall_proportions_struct)
209}
210
211fn get_overall_fields(
212    overall_struct: &StructArray,
213) -> Result<(&ListArray, &ListArray), DataFrameError> {
214    let overall_bin_ids = overall_struct
215        .column_by_name("bin_id")
216        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
217        .as_any()
218        .downcast_ref::<ListArray>()
219        .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;
220
221    let overall_proportions = overall_struct
222        .column_by_name("proportion")
223        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
224        .as_any()
225        .downcast_ref::<ListArray>()
226        .ok_or_else(|| DataFrameError::DowncastError("proporition"))?;
227
228    Ok((overall_bin_ids, overall_proportions))
229}
230
231fn extract_overall_proportions(batch: &RecordBatch) -> Result<BTreeMap<i32, f64>, DataFrameError> {
232    let overall_struct = get_overall_proportions_struct(batch)?;
233    let (bin_ids, proportions) = get_overall_fields(overall_struct)?;
234
235    let bin_ids = get_bin_ids(&bin_ids.value(0))?;
236    let proportions = get_proportions(&proportions.value(0))?;
237
238    Ok(bin_ids.into_iter().zip(proportions).collect())
239}
240
/// Helper function to build a `FeatureBinProportionResult` from a record batch.
///
/// # Arguments
/// * `batch` - The record batch to process. The feature name is taken from the
///   first value of the feature column, so each batch is assumed to describe a
///   single feature (NOTE(review): confirm upstream grouping guarantees this).
///
/// # Returns
/// * `Result<FeatureBinProportionResult, DataFrameError>` - the feature name,
///   its created-at values, the per-row bin proportion maps, and the overall
///   bin proportion map.
fn process_psi_record_batch(
    batch: &RecordBatch,
) -> Result<FeatureBinProportionResult, DataFrameError> {
    Ok(FeatureBinProportionResult {
        feature: ParquetHelper::extract_feature_array(batch)?
            .value(0)
            .to_string(),
        created_at: ParquetHelper::extract_created_at(batch)?,
        bin_proportions: extract_bin_proportions(batch)?,
        overall_proportions: extract_overall_proportions(batch)?,
    })
}

/// Convert a DataFrame to PSI feature bin proportion results.
///
/// # Arguments
/// * `df` - The DataFrame to convert; collected into record batches, each of
///   which is processed into one `FeatureBinProportionResult`.
///
/// # Returns
/// * `Result<Vec<FeatureBinProportionResult>, DataFrameError>` - one result
///   per collected record batch (fails on the first batch that cannot be
///   processed).
pub async fn dataframe_to_psi_drift_features(
    df: DataFrame,
) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
    let batches = df.collect().await?;

    batches
        .into_iter()
        .map(|batch| process_psi_record_batch(&batch))
        .collect()
}