scouter_dataframe/parquet/
psi.rs

1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::sql::helper::get_binned_psi_drift_records_query;
4use crate::storage::ObjectStore;
5use arrow::array::AsArray;
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7use arrow_array::array::{
8    ListArray, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, UInt64Array,
9};
10use arrow_array::types::{Float32Type, UInt64Type};
11use arrow_array::Array;
12use arrow_array::RecordBatch;
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use datafusion::dataframe::DataFrame;
16use datafusion::prelude::SessionContext;
17use scouter_settings::ObjectStorageSettings;
18use scouter_types::{
19    psi::FeatureBinProportionResult, PsiServerRecord, ServerRecords, StorageType, ToDriftRecords,
20};
21use std::collections::BTreeMap;
22use std::sync::Arc;
23
24use super::types::BinnedTableName;
/// Parquet-backed frame for PSI (Population Stability Index) drift records,
/// bridging `PsiServerRecord`s and DataFusion/arrow via an object store.
pub struct PsiDataFrame {
    // Arrow schema for PSI records; built once in `new` and reused by `build_batch`.
    schema: Arc<Schema>,
    // Backing object store handle; also supplies the DataFusion session.
    pub object_store: ObjectStore,
}
29
30#[async_trait]
31impl ParquetFrame for PsiDataFrame {
32    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
33        PsiDataFrame::new(storage_settings)
34    }
35
36    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
37        let records = records.to_psi_drift_records()?;
38        let batch = self.build_batch(records)?;
39
40        let ctx = self.object_store.get_session()?;
41
42        let df = ctx.read_batches(vec![batch])?;
43        Ok(df)
44    }
45
46    fn storage_root(&self) -> String {
47        self.object_store.storage_settings.canonicalized_path()
48    }
49
50    fn storage_type(&self) -> StorageType {
51        self.object_store.storage_settings.storage_type.clone()
52    }
53
54    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
55        Ok(self.object_store.get_session()?)
56    }
57
58    fn get_binned_sql(
59        &self,
60        bin: &f64,
61        start_time: &DateTime<Utc>,
62        end_time: &DateTime<Utc>,
63        space: &str,
64        name: &str,
65        version: &str,
66    ) -> String {
67        get_binned_psi_drift_records_query(bin, start_time, end_time, space, name, version)
68    }
69
70    fn table_name(&self) -> String {
71        BinnedTableName::Psi.to_string()
72    }
73}
74
75impl PsiDataFrame {
76    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
77        let schema = Arc::new(Schema::new(vec![
78            Field::new(
79                "created_at",
80                DataType::Timestamp(TimeUnit::Nanosecond, None),
81                false,
82            ),
83            Field::new("space", DataType::Utf8, false),
84            Field::new("name", DataType::Utf8, false),
85            Field::new("version", DataType::Utf8, false),
86            Field::new("feature", DataType::Utf8, false),
87            Field::new("bin_id", DataType::UInt64, false),
88            Field::new("bin_count", DataType::UInt64, false),
89        ]));
90
91        let object_store = ObjectStore::new(storage_settings)?;
92
93        Ok(PsiDataFrame {
94            schema,
95            object_store,
96        })
97    }
98
99    /// Create and arrow RecordBatch from the given records
100    fn build_batch(&self, records: Vec<PsiServerRecord>) -> Result<RecordBatch, DataFrameError> {
101        let created_at_array = TimestampNanosecondArray::from_iter_values(
102            records
103                .iter()
104                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
105        );
106
107        let space_array = StringArray::from_iter_values(records.iter().map(|r| r.space.as_str()));
108        let name_array = StringArray::from_iter_values(records.iter().map(|r| r.name.as_str()));
109        let version_array =
110            StringArray::from_iter_values(records.iter().map(|r| r.version.as_str()));
111        let feature_array =
112            StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));
113
114        let bin_id_array = UInt64Array::from_iter_values(records.iter().map(|r| r.bin_id as u64));
115        let bin_count_array =
116            UInt64Array::from_iter_values(records.iter().map(|r| r.bin_count as u64));
117
118        let batch = RecordBatch::try_new(
119            self.schema.clone(),
120            vec![
121                Arc::new(created_at_array),
122                Arc::new(space_array),
123                Arc::new(name_array),
124                Arc::new(version_array),
125                Arc::new(feature_array),
126                Arc::new(bin_id_array),
127                Arc::new(bin_count_array),
128            ],
129        )?;
130
131        Ok(batch)
132    }
133}
134
135/// Extraction logic to get feature from a return record batch
136fn extract_feature(batch: &RecordBatch) -> Result<String, DataFrameError> {
137    let feature_array = batch
138        .column(0)
139        .as_any()
140        .downcast_ref::<StringViewArray>()
141        .ok_or_else(|| DataFrameError::GetColumnError("feature"))?;
142    Ok(feature_array.value(0).to_string())
143}
144
145/// Extraction logic to get created_at from a return record batch
146fn extract_created_at(batch: &RecordBatch) -> Result<Vec<DateTime<Utc>>, DataFrameError> {
147    let created_at_list = batch
148        .column(1)
149        .as_any()
150        .downcast_ref::<ListArray>()
151        .ok_or_else(|| DataFrameError::GetColumnError("created_at"))?;
152
153    let created_at_array = created_at_list.value(0);
154    Ok(created_at_array
155        .as_primitive::<arrow::datatypes::TimestampNanosecondType>()
156        .iter()
157        .filter_map(|ts| ts.map(|t| Utc.timestamp_nanos(t)))
158        .collect())
159}
160
161/// Extraction logic to get bin proportions from a return record batch
162fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
163    batch
164        .column(2)
165        .as_any()
166        .downcast_ref::<ListArray>()
167        .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
168}
169
170/// Extraction logic to get bin ids and proportions from a return record batch
171fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
172    let bin_ids = structs
173        .column_by_name("bin_id")
174        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
175        .as_any()
176        .downcast_ref::<ListArray>()
177        .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;
178
179    let proportions = structs
180        .column_by_name("proportion")
181        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
182        .as_any()
183        .downcast_ref::<ListArray>()
184        .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;
185
186    Ok((bin_ids, proportions))
187}
188
189/// Convert the bin id array to a Vec<usize>
190fn get_bin_ids(array: &dyn Array) -> Result<Vec<usize>, DataFrameError> {
191    Ok(array
192        .as_primitive::<UInt64Type>()
193        .iter()
194        .filter_map(|id| id.map(|i| i as usize))
195        .collect())
196}
197
198/// Convert the proportion array to a Vec<f64>
199/// TODO: Should we store f64 or f32?
200fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
201    Ok(array
202        .as_primitive::<Float32Type>()
203        .iter()
204        .filter_map(|p| p.map(|v| v as f64))
205        .collect())
206}
207
208/// Create a BTreeMap from the bin ids and proportions
209fn create_bin_map(
210    bin_ids: &ListArray,
211    proportions: &ListArray,
212    index: usize,
213) -> Result<BTreeMap<usize, f64>, DataFrameError> {
214    let bin_ids = get_bin_ids(&bin_ids.value(index))?;
215    let proportions = get_proportions(&proportions.value(index))?;
216
217    Ok(bin_ids.into_iter().zip(proportions).collect())
218}
219
220/// Extract bin proportions from a return record batch
221fn extract_bin_proportions(
222    batch: &RecordBatch,
223) -> Result<Vec<BTreeMap<usize, f64>>, DataFrameError> {
224    let bin_structs = get_bin_proportions_struct(batch)?.value(0);
225    let bin_structs = bin_structs
226        .as_any()
227        .downcast_ref::<StructArray>()
228        .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;
229
230    let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;
231
232    let mut bin_proportions = Vec::with_capacity(bin_structs.len());
233    for i in 0..bin_structs.len() {
234        let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
235        bin_proportions.push(bin_map);
236    }
237
238    Ok(bin_proportions)
239}
240
241/// Extract overall proportions from a return record batch
242fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
243    let overall_proportions_struct = batch
244        .column(3)
245        .as_any()
246        .downcast_ref::<StructArray>()
247        .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;
248
249    Ok(overall_proportions_struct)
250}
251
252fn get_overall_fields(
253    overall_struct: &StructArray,
254) -> Result<(&ListArray, &ListArray), DataFrameError> {
255    let overall_bin_ids = overall_struct
256        .column_by_name("bin_id")
257        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
258        .as_any()
259        .downcast_ref::<ListArray>()
260        .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;
261
262    let overall_proportions = overall_struct
263        .column_by_name("proportion")
264        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
265        .as_any()
266        .downcast_ref::<ListArray>()
267        .ok_or_else(|| DataFrameError::DowncastError("proporition"))?;
268
269    Ok((overall_bin_ids, overall_proportions))
270}
271
272fn extract_overall_proportions(
273    batch: &RecordBatch,
274) -> Result<BTreeMap<usize, f64>, DataFrameError> {
275    let overall_struct = get_overall_proportions_struct(batch)?;
276    let (bin_ids, proportions) = get_overall_fields(overall_struct)?;
277
278    let bin_ids = get_bin_ids(&bin_ids.value(0))?;
279    let proportions = get_proportions(&proportions.value(0))?;
280
281    Ok(bin_ids.into_iter().zip(proportions).collect())
282}
283
/// Helper function to process a single record batch into a
/// `FeatureBinProportionResult`.
///
/// Expects the columns in the order produced by the binned PSI query:
/// feature (0), created_at (1), bin_proportions (2), overall_proportions (3).
///
/// # Arguments
/// * `batch` - The record batch to process
///
/// # Returns
/// * `Result<FeatureBinProportionResult, DataFrameError>` - The assembled
///   result, or the first extraction error encountered
fn process_psi_record_batch(
    batch: &RecordBatch,
) -> Result<FeatureBinProportionResult, DataFrameError> {
    Ok(FeatureBinProportionResult {
        feature: extract_feature(batch)?,
        created_at: extract_created_at(batch)?,
        bin_proportions: extract_bin_proportions(batch)?,
        overall_proportions: extract_overall_proportions(batch)?,
    })
}
302
/// Convert a DataFrame to a list of `FeatureBinProportionResult`s, one per
/// collected record batch.
///
/// # Arguments
/// * `df` - The DataFrame to convert (result of a binned PSI query)
///
/// # Returns
/// * `Result<Vec<FeatureBinProportionResult>, DataFrameError>` - The converted
///   results; fails on the first batch that cannot be processed
pub async fn dataframe_to_psi_drift_features(
    df: DataFrame,
) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
    let batches = df.collect().await?;

    batches
        .into_iter()
        .map(|batch| process_psi_record_batch(&batch))
        .collect()
}