scouter_dataframe/parquet/psi.rs

use crate::error::DataFrameError;
use crate::parquet::traits::ParquetFrame;
use crate::parquet::utils::ParquetHelper;
use crate::sql::helper::get_binned_psi_drift_records_query;
use crate::storage::ObjectStore;
use arrow::array::AsArray;
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use arrow_array::array::{
    ListArray, StringArray, StructArray, TimestampNanosecondArray, UInt64Array,
};
use arrow_array::types::{Float32Type, UInt64Type};
use arrow_array::Array;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion::dataframe::DataFrame;
use datafusion::prelude::SessionContext;
use scouter_settings::ObjectStorageSettings;
use scouter_types::{
    psi::FeatureBinProportionResult, PsiServerRecord, ServerRecords, StorageType, ToDriftRecords,
};
use std::collections::BTreeMap;
use std::sync::Arc;

use super::types::BinnedTableName;
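
/// DataFusion/Parquet adapter for PSI (population stability index) drift records.
/// Holds the Arrow schema used to serialize `PsiServerRecord`s along with the
/// object store that the resulting data is written to and read from.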
pub struct PsiDataFrame {
    schema: Arc<Schema>,
    pub object_store: ObjectStore,
}

#[async_trait]
impl ParquetFrame for PsiDataFrame {
    fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
        PsiDataFrame::new(storage_settings)
    }
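
    /// Converts the incoming `ServerRecords` into PSI drift records, builds a
    /// single `RecordBatch` from them, and loads that batch into a DataFusion
    /// `DataFrame` using the object store's session context.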
    async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
        let records = records.to_psi_drift_records()?;
        let batch = self.build_batch(records)?;

        let ctx = self.object_store.get_session()?;

        let df = ctx.read_batches(vec![batch])?;
        Ok(df)
    }

    fn storage_root(&self) -> String {
        self.object_store.storage_settings.canonicalized_path()
    }

    fn storage_type(&self) -> StorageType {
        self.object_store.storage_settings.storage_type.clone()
    }

    fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
        Ok(self.object_store.get_session()?)
    }
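
    /// Builds the SQL query for binned PSI drift records over the given time
    /// window and space/name/version.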
    fn get_binned_sql(
        &self,
        bin: &f64,
        start_time: &DateTime<Utc>,
        end_time: &DateTime<Utc>,
        space: &str,
        name: &str,
        version: &str,
    ) -> String {
        get_binned_psi_drift_records_query(bin, start_time, end_time, space, name, version)
    }

    fn table_name(&self) -> String {
        BinnedTableName::Psi.to_string()
    }
}

impl PsiDataFrame {
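    /// Builds the flat Arrow schema for PSI server records (created_at, space,
    /// name, version, feature, bin_id, bin_count) and initializes the object
    /// store from the provided storage settings.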
    pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "created_at",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
            Field::new("space", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("version", DataType::Utf8, false),
            Field::new("feature", DataType::Utf8, false),
            Field::new("bin_id", DataType::UInt64, false),
            Field::new("bin_count", DataType::UInt64, false),
        ]));

        let object_store = ObjectStore::new(storage_settings)?;

        Ok(PsiDataFrame {
            schema,
            object_store,
        })
    }
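
    /// Maps each `PsiServerRecord` field onto its Arrow array and assembles a
    /// `RecordBatch` matching `self.schema`. Timestamps are stored as
    /// nanoseconds; values that cannot be represented fall back to the default (0).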
    fn build_batch(&self, records: Vec<PsiServerRecord>) -> Result<RecordBatch, DataFrameError> {
        let created_at_array = TimestampNanosecondArray::from_iter_values(
            records
                .iter()
                .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
        );

        let space_array = StringArray::from_iter_values(records.iter().map(|r| r.space.as_str()));
        let name_array = StringArray::from_iter_values(records.iter().map(|r| r.name.as_str()));
        let version_array =
            StringArray::from_iter_values(records.iter().map(|r| r.version.as_str()));
        let feature_array =
            StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));

        let bin_id_array = UInt64Array::from_iter_values(records.iter().map(|r| r.bin_id as u64));
        let bin_count_array =
            UInt64Array::from_iter_values(records.iter().map(|r| r.bin_count as u64));

        let batch = RecordBatch::try_new(
            self.schema.clone(),
            vec![
                Arc::new(created_at_array),
                Arc::new(space_array),
                Arc::new(name_array),
                Arc::new(version_array),
                Arc::new(feature_array),
                Arc::new(bin_id_array),
                Arc::new(bin_count_array),
            ],
        )?;

        Ok(batch)
    }
}
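
/// Returns the `bin_proportions` list column (column index 2) of a batch
/// produced by the binned PSI query.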
fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
    batch
        .column(2)
        .as_any()
        .downcast_ref::<ListArray>()
        .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
}
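
/// Extracts the `bin_id` and `proportion` list fields from the binned
/// proportions struct array.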
fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
    let bin_ids = structs
        .column_by_name("bin_id")
        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
        .as_any()
        .downcast_ref::<ListArray>()
        .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;

    let proportions = structs
        .column_by_name("proportion")
        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
        .as_any()
        .downcast_ref::<ListArray>()
        .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;

    Ok((bin_ids, proportions))
}
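
/// Collects a `UInt64` array of bin ids into `usize` values, skipping nulls.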
fn get_bin_ids(array: &dyn Array) -> Result<Vec<usize>, DataFrameError> {
    Ok(array
        .as_primitive::<UInt64Type>()
        .iter()
        .filter_map(|id| id.map(|i| i as usize))
        .collect())
}
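
/// Collects a `Float32` array of proportions into `f64` values, skipping nulls.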
fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
    Ok(array
        .as_primitive::<Float32Type>()
        .iter()
        .filter_map(|p| p.map(|v| v as f64))
        .collect())
}
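
/// Zips the bin ids and proportions at `index` into an ordered map of
/// bin id to proportion.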
fn create_bin_map(
    bin_ids: &ListArray,
    proportions: &ListArray,
    index: usize,
) -> Result<BTreeMap<usize, f64>, DataFrameError> {
    let bin_ids = get_bin_ids(&bin_ids.value(index))?;
    let proportions = get_proportions(&proportions.value(index))?;

    Ok(bin_ids.into_iter().zip(proportions).collect())
}
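
/// Builds one bin id -> proportion map per entry in the batch's
/// `bin_proportions` column.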
fn extract_bin_proportions(
    batch: &RecordBatch,
) -> Result<Vec<BTreeMap<usize, f64>>, DataFrameError> {
    let bin_structs = get_bin_proportions_struct(batch)?.value(0);
    let bin_structs = bin_structs
        .as_any()
        .downcast_ref::<StructArray>()
        .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;

    let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;

    let mut bin_proportions = Vec::with_capacity(bin_structs.len());
    for i in 0..bin_structs.len() {
        let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
        bin_proportions.push(bin_map);
    }

    Ok(bin_proportions)
}
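
/// Returns the `overall_proportions` struct column (column index 3) of a batch
/// produced by the binned PSI query.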
fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
    let overall_proportions_struct = batch
        .column(3)
        .as_any()
        .downcast_ref::<StructArray>()
        .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;

    Ok(overall_proportions_struct)
}
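
/// Extracts the `bin_id` and `proportion` list fields from the overall
/// proportions struct array.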
fn get_overall_fields(
    overall_struct: &StructArray,
) -> Result<(&ListArray, &ListArray), DataFrameError> {
    let overall_bin_ids = overall_struct
        .column_by_name("bin_id")
        .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
        .as_any()
        .downcast_ref::<ListArray>()
        .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;

    let overall_proportions = overall_struct
        .column_by_name("proportion")
        .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
        .as_any()
        .downcast_ref::<ListArray>()
        .ok_or_else(|| DataFrameError::DowncastError("proportion"))?;

    Ok((overall_bin_ids, overall_proportions))
}
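
/// Flattens the overall proportions into a single bin id -> proportion map.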
fn extract_overall_proportions(
    batch: &RecordBatch,
) -> Result<BTreeMap<usize, f64>, DataFrameError> {
    let overall_struct = get_overall_proportions_struct(batch)?;
    let (bin_ids, proportions) = get_overall_fields(overall_struct)?;

    let bin_ids = get_bin_ids(&bin_ids.value(0))?;
    let proportions = get_proportions(&proportions.value(0))?;

    Ok(bin_ids.into_iter().zip(proportions).collect())
}
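
/// Converts a single record batch (assumed to hold one feature, hence
/// `value(0)`) into a `FeatureBinProportionResult`.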
fn process_psi_record_batch(
    batch: &RecordBatch,
) -> Result<FeatureBinProportionResult, DataFrameError> {
    Ok(FeatureBinProportionResult {
        feature: ParquetHelper::extract_feature_array(batch)?
            .value(0)
            .to_string(),
        created_at: ParquetHelper::extract_created_at(batch)?,
        bin_proportions: extract_bin_proportions(batch)?,
        overall_proportions: extract_overall_proportions(batch)?,
    })
}
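
/// Collects the DataFrame into record batches and converts each batch into a
/// `FeatureBinProportionResult`.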
pub async fn dataframe_to_psi_drift_features(
    df: DataFrame,
) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
    let batches = df.collect().await?;

    batches
        .into_iter()
        .map(|batch| process_psi_record_batch(&batch))
        .collect()
}