// scouter_dataframe/parquet/psi.rs
use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::parquet::utils::ParquetHelper;
4use crate::sql::helper::get_binned_psi_drift_records_query;
5use crate::storage::ObjectStore;
6use arrow::array::AsArray;
7use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
8use arrow_array::array::{
9 Int32Array, ListArray, StringArray, StructArray, TimestampNanosecondArray,
10};
11use arrow_array::types::{Float32Type, Int32Type};
12use arrow_array::Array;
13use arrow_array::RecordBatch;
14use async_trait::async_trait;
15use chrono::{DateTime, Utc};
16use datafusion::dataframe::DataFrame;
17use datafusion::prelude::SessionContext;
18use scouter_settings::ObjectStorageSettings;
19use scouter_types::{
20 psi::FeatureBinProportionResult, PsiRecord, ServerRecords, StorageType, ToDriftRecords,
21};
22use std::collections::BTreeMap;
23use std::sync::Arc;
24
25use super::types::BinnedTableName;
26pub struct PsiDataFrame {
27 schema: Arc<Schema>,
28 pub object_store: ObjectStore,
29}
30
31#[async_trait]
32impl ParquetFrame for PsiDataFrame {
33 fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
34 PsiDataFrame::new(storage_settings)
35 }
36
37 async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
38 let records = records.to_psi_drift_records()?;
39 let batch = self.build_batch(records)?;
40
41 let ctx = self.object_store.get_session()?;
42
43 let df = ctx.read_batches(vec![batch])?;
44 Ok(df)
45 }
46
47 fn storage_root(&self) -> String {
48 self.object_store.storage_settings.canonicalized_path()
49 }
50
51 fn storage_type(&self) -> StorageType {
52 self.object_store.storage_settings.storage_type.clone()
53 }
54
55 fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
56 Ok(self.object_store.get_session()?)
57 }
58
59 fn get_binned_sql(
60 &self,
61 bin: &f64,
62 start_time: &DateTime<Utc>,
63 end_time: &DateTime<Utc>,
64 entity_id: &i32,
65 ) -> String {
66 get_binned_psi_drift_records_query(bin, start_time, end_time, entity_id)
67 }
68
69 fn table_name(&self) -> String {
70 BinnedTableName::Psi.to_string()
71 }
72}
73
74impl PsiDataFrame {
75 pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
76 let schema = Arc::new(Schema::new(vec![
77 Field::new(
78 "created_at",
79 DataType::Timestamp(TimeUnit::Nanosecond, None),
80 false,
81 ),
82 Field::new("entity_id", DataType::Int32, false),
83 Field::new("feature", DataType::Utf8, false),
84 Field::new("bin_id", DataType::Int32, false),
85 Field::new("bin_count", DataType::Int32, false),
86 ]));
87
88 let object_store = ObjectStore::new(storage_settings)?;
89
90 Ok(PsiDataFrame {
91 schema,
92 object_store,
93 })
94 }
95
96 fn build_batch(&self, records: Vec<PsiRecord>) -> Result<RecordBatch, DataFrameError> {
98 let created_at_array = TimestampNanosecondArray::from_iter_values(
99 records
100 .iter()
101 .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
102 );
103
104 let entity_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.entity_id));
105 let feature_array =
106 StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));
107
108 let bin_id_array = Int32Array::from_iter_values(records.iter().map(|r| r.bin_id));
109 let bin_count_array = Int32Array::from_iter_values(records.iter().map(|r| r.bin_count));
110
111 let batch = RecordBatch::try_new(
112 self.schema.clone(),
113 vec![
114 Arc::new(created_at_array),
115 Arc::new(entity_id_array),
116 Arc::new(feature_array),
117 Arc::new(bin_id_array),
118 Arc::new(bin_count_array),
119 ],
120 )?;
121
122 Ok(batch)
123 }
124}
125
126fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
128 batch
129 .column(2)
130 .as_any()
131 .downcast_ref::<ListArray>()
132 .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
133}
134
135fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
137 let bin_ids = structs
138 .column_by_name("bin_id")
139 .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
140 .as_any()
141 .downcast_ref::<ListArray>()
142 .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;
143
144 let proportions = structs
145 .column_by_name("proportion")
146 .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
147 .as_any()
148 .downcast_ref::<ListArray>()
149 .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;
150
151 Ok((bin_ids, proportions))
152}
153
154fn get_bin_ids(array: &dyn Array) -> Result<Vec<i32>, DataFrameError> {
156 Ok(array.as_primitive::<Int32Type>().iter().flatten().collect())
157}
158
159fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
162 Ok(array
163 .as_primitive::<Float32Type>()
164 .iter()
165 .filter_map(|p| p.map(|v| v as f64))
166 .collect())
167}
168
169fn create_bin_map(
171 bin_ids: &ListArray,
172 proportions: &ListArray,
173 index: usize,
174) -> Result<BTreeMap<i32, f64>, DataFrameError> {
175 let bin_ids = get_bin_ids(&bin_ids.value(index))?;
176 let proportions = get_proportions(&proportions.value(index))?;
177
178 Ok(bin_ids.into_iter().zip(proportions).collect())
179}
180
181fn extract_bin_proportions(batch: &RecordBatch) -> Result<Vec<BTreeMap<i32, f64>>, DataFrameError> {
183 let bin_structs = get_bin_proportions_struct(batch)?.value(0);
184 let bin_structs = bin_structs
185 .as_any()
186 .downcast_ref::<StructArray>()
187 .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;
188
189 let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;
190
191 let mut bin_proportions = Vec::with_capacity(bin_structs.len());
192 for i in 0..bin_structs.len() {
193 let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
194 bin_proportions.push(bin_map);
195 }
196
197 Ok(bin_proportions)
198}
199
200fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
202 let overall_proportions_struct = batch
203 .column(3)
204 .as_any()
205 .downcast_ref::<StructArray>()
206 .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;
207
208 Ok(overall_proportions_struct)
209}
210
211fn get_overall_fields(
212 overall_struct: &StructArray,
213) -> Result<(&ListArray, &ListArray), DataFrameError> {
214 let overall_bin_ids = overall_struct
215 .column_by_name("bin_id")
216 .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
217 .as_any()
218 .downcast_ref::<ListArray>()
219 .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;
220
221 let overall_proportions = overall_struct
222 .column_by_name("proportion")
223 .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
224 .as_any()
225 .downcast_ref::<ListArray>()
226 .ok_or_else(|| DataFrameError::DowncastError("proporition"))?;
227
228 Ok((overall_bin_ids, overall_proportions))
229}
230
231fn extract_overall_proportions(batch: &RecordBatch) -> Result<BTreeMap<i32, f64>, DataFrameError> {
232 let overall_struct = get_overall_proportions_struct(batch)?;
233 let (bin_ids, proportions) = get_overall_fields(overall_struct)?;
234
235 let bin_ids = get_bin_ids(&bin_ids.value(0))?;
236 let proportions = get_proportions(&proportions.value(0))?;
237
238 Ok(bin_ids.into_iter().zip(proportions).collect())
239}
240
241fn process_psi_record_batch(
250 batch: &RecordBatch,
251) -> Result<FeatureBinProportionResult, DataFrameError> {
252 Ok(FeatureBinProportionResult {
253 feature: ParquetHelper::extract_feature_array(batch)?
254 .value(0)
255 .to_string(),
256 created_at: ParquetHelper::extract_created_at(batch)?,
257 bin_proportions: extract_bin_proportions(batch)?,
258 overall_proportions: extract_overall_proportions(batch)?,
259 })
260}
261
262pub async fn dataframe_to_psi_drift_features(
270 df: DataFrame,
271) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
272 let batches = df.collect().await?;
273
274 batches
275 .into_iter()
276 .map(|batch| process_psi_record_batch(&batch))
277 .collect()
278}