// scouter_dataframe/parquet/psi.rs

1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::parquet::utils::ParquetHelper;
4use crate::sql::helper::get_binned_psi_drift_records_query;
5use crate::storage::ObjectStore;
6use arrow::array::AsArray;
7use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
8use arrow_array::array::{
9    ListArray, StringArray, StructArray, TimestampNanosecondArray, UInt64Array,
10};
11use arrow_array::types::{Float32Type, UInt64Type};
12use arrow_array::Array;
13use arrow_array::RecordBatch;
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use datafusion::dataframe::DataFrame;
17use datafusion::prelude::SessionContext;
18use scouter_settings::ObjectStorageSettings;
19use scouter_types::{
20    psi::FeatureBinProportionResult, PsiServerRecord, ServerRecords, StorageType, ToDriftRecords,
21};
22use std::collections::BTreeMap;
23use std::sync::Arc;
24
25use super::types::BinnedTableName;
/// Parquet-backed frame for PSI (Population Stability Index) drift records:
/// builds Arrow record batches from server records and queries them via
/// DataFusion through the configured object store.
pub struct PsiDataFrame {
    // Fixed Arrow schema used when building record batches in `build_batch`.
    schema: Arc<Schema>,
    // Object storage backend created from `ObjectStorageSettings` in `new`.
    pub object_store: ObjectStore,
}
30
31#[async_trait]
32impl ParquetFrame for PsiDataFrame {
33    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
34        PsiDataFrame::new(storage_settings)
35    }
36
37    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
38        let records = records.to_psi_drift_records()?;
39        let batch = self.build_batch(records)?;
40
41        let ctx = self.object_store.get_session()?;
42
43        let df = ctx.read_batches(vec![batch])?;
44        Ok(df)
45    }
46
47    fn storage_root(&self) -> String {
48        self.object_store.storage_settings.canonicalized_path()
49    }
50
51    fn storage_type(&self) -> StorageType {
52        self.object_store.storage_settings.storage_type.clone()
53    }
54
55    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
56        Ok(self.object_store.get_session()?)
57    }
58
59    fn get_binned_sql(
60        &self,
61        bin: &f64,
62        start_time: &DateTime<Utc>,
63        end_time: &DateTime<Utc>,
64        space: &str,
65        name: &str,
66        version: &str,
67    ) -> String {
68        get_binned_psi_drift_records_query(bin, start_time, end_time, space, name, version)
69    }
70
71    fn table_name(&self) -> String {
72        BinnedTableName::Psi.to_string()
73    }
74}
75
76impl PsiDataFrame {
77    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
78        let schema = Arc::new(Schema::new(vec![
79            Field::new(
80                "created_at",
81                DataType::Timestamp(TimeUnit::Nanosecond, None),
82                false,
83            ),
84            Field::new("space", DataType::Utf8, false),
85            Field::new("name", DataType::Utf8, false),
86            Field::new("version", DataType::Utf8, false),
87            Field::new("feature", DataType::Utf8, false),
88            Field::new("bin_id", DataType::UInt64, false),
89            Field::new("bin_count", DataType::UInt64, false),
90        ]));
91
92        let object_store = ObjectStore::new(storage_settings)?;
93
94        Ok(PsiDataFrame {
95            schema,
96            object_store,
97        })
98    }
99
100    /// Create and arrow RecordBatch from the given records
101    fn build_batch(&self, records: Vec<PsiServerRecord>) -> Result<RecordBatch, DataFrameError> {
102        let created_at_array = TimestampNanosecondArray::from_iter_values(
103            records
104                .iter()
105                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
106        );
107
108        let space_array = StringArray::from_iter_values(records.iter().map(|r| r.space.as_str()));
109        let name_array = StringArray::from_iter_values(records.iter().map(|r| r.name.as_str()));
110        let version_array =
111            StringArray::from_iter_values(records.iter().map(|r| r.version.as_str()));
112        let feature_array =
113            StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));
114
115        let bin_id_array = UInt64Array::from_iter_values(records.iter().map(|r| r.bin_id as u64));
116        let bin_count_array =
117            UInt64Array::from_iter_values(records.iter().map(|r| r.bin_count as u64));
118
119        let batch = RecordBatch::try_new(
120            self.schema.clone(),
121            vec![
122                Arc::new(created_at_array),
123                Arc::new(space_array),
124                Arc::new(name_array),
125                Arc::new(version_array),
126                Arc::new(feature_array),
127                Arc::new(bin_id_array),
128                Arc::new(bin_count_array),
129            ],
130        )?;
131
132        Ok(batch)
133    }
134}
135
136/// Extraction logic to get bin proportions from a return record batch
137fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
138    batch
139        .column(2)
140        .as_any()
141        .downcast_ref::<ListArray>()
142        .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
143}
144
145/// Extraction logic to get bin ids and proportions from a return record batch
146fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
147    let bin_ids = structs
148        .column_by_name("bin_id")
149        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
150        .as_any()
151        .downcast_ref::<ListArray>()
152        .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;
153
154    let proportions = structs
155        .column_by_name("proportion")
156        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
157        .as_any()
158        .downcast_ref::<ListArray>()
159        .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;
160
161    Ok((bin_ids, proportions))
162}
163
164/// Convert the bin id array to a Vec<usize>
165fn get_bin_ids(array: &dyn Array) -> Result<Vec<usize>, DataFrameError> {
166    Ok(array
167        .as_primitive::<UInt64Type>()
168        .iter()
169        .filter_map(|id| id.map(|i| i as usize))
170        .collect())
171}
172
173/// Convert the proportion array to a Vec<f64>
174/// TODO: Should we store f64 or f32?
175fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
176    Ok(array
177        .as_primitive::<Float32Type>()
178        .iter()
179        .filter_map(|p| p.map(|v| v as f64))
180        .collect())
181}
182
183/// Create a BTreeMap from the bin ids and proportions
184fn create_bin_map(
185    bin_ids: &ListArray,
186    proportions: &ListArray,
187    index: usize,
188) -> Result<BTreeMap<usize, f64>, DataFrameError> {
189    let bin_ids = get_bin_ids(&bin_ids.value(index))?;
190    let proportions = get_proportions(&proportions.value(index))?;
191
192    Ok(bin_ids.into_iter().zip(proportions).collect())
193}
194
195/// Extract bin proportions from a return record batch
196fn extract_bin_proportions(
197    batch: &RecordBatch,
198) -> Result<Vec<BTreeMap<usize, f64>>, DataFrameError> {
199    let bin_structs = get_bin_proportions_struct(batch)?.value(0);
200    let bin_structs = bin_structs
201        .as_any()
202        .downcast_ref::<StructArray>()
203        .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;
204
205    let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;
206
207    let mut bin_proportions = Vec::with_capacity(bin_structs.len());
208    for i in 0..bin_structs.len() {
209        let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
210        bin_proportions.push(bin_map);
211    }
212
213    Ok(bin_proportions)
214}
215
216/// Extract overall proportions from a return record batch
217fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
218    let overall_proportions_struct = batch
219        .column(3)
220        .as_any()
221        .downcast_ref::<StructArray>()
222        .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;
223
224    Ok(overall_proportions_struct)
225}
226
227fn get_overall_fields(
228    overall_struct: &StructArray,
229) -> Result<(&ListArray, &ListArray), DataFrameError> {
230    let overall_bin_ids = overall_struct
231        .column_by_name("bin_id")
232        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
233        .as_any()
234        .downcast_ref::<ListArray>()
235        .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;
236
237    let overall_proportions = overall_struct
238        .column_by_name("proportion")
239        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
240        .as_any()
241        .downcast_ref::<ListArray>()
242        .ok_or_else(|| DataFrameError::DowncastError("proporition"))?;
243
244    Ok((overall_bin_ids, overall_proportions))
245}
246
247fn extract_overall_proportions(
248    batch: &RecordBatch,
249) -> Result<BTreeMap<usize, f64>, DataFrameError> {
250    let overall_struct = get_overall_proportions_struct(batch)?;
251    let (bin_ids, proportions) = get_overall_fields(overall_struct)?;
252
253    let bin_ids = get_bin_ids(&bin_ids.value(0))?;
254    let proportions = get_proportions(&proportions.value(0))?;
255
256    Ok(bin_ids.into_iter().zip(proportions).collect())
257}
258
259/// Helper function to process a record batch to feature and SpcDriftFeature
260///
261/// # Arguments
262/// * `batch` - The record batch to process
263/// * `features` - The features to populate
264///
265/// # Returns
266/// * `Result<(), DataFrameError>` - The result of the processing
267fn process_psi_record_batch(
268    batch: &RecordBatch,
269) -> Result<FeatureBinProportionResult, DataFrameError> {
270    Ok(FeatureBinProportionResult {
271        feature: ParquetHelper::extract_feature_array(batch)?
272            .value(0)
273            .to_string(),
274        created_at: ParquetHelper::extract_created_at(batch)?,
275        bin_proportions: extract_bin_proportions(batch)?,
276        overall_proportions: extract_overall_proportions(batch)?,
277    })
278}
279
280/// Convert a DataFrame to SpcDriftFeatures
281///
282/// # Arguments
283/// * `df` - The DataFrame to convert
284///
285/// # Returns
286/// * `SpcDriftFeatures` - The converted SpcDriftFeatures
287pub async fn dataframe_to_psi_drift_features(
288    df: DataFrame,
289) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
290    let batches = df.collect().await?;
291
292    batches
293        .into_iter()
294        .map(|batch| process_psi_record_batch(&batch))
295        .collect()
296}