scouter_dataframe/parquet/
psi.rs1use crate::error::DataFrameError;
2use crate::parquet::traits::ParquetFrame;
3use crate::sql::helper::get_binned_psi_drift_records_query;
4use crate::storage::ObjectStore;
5use arrow::array::AsArray;
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7use arrow_array::array::{
8 ListArray, StringArray, StringViewArray, StructArray, TimestampNanosecondArray, UInt64Array,
9};
10use arrow_array::types::{Float32Type, UInt64Type};
11use arrow_array::Array;
12use arrow_array::RecordBatch;
13use async_trait::async_trait;
14use chrono::{DateTime, TimeZone, Utc};
15use datafusion::dataframe::DataFrame;
16use datafusion::prelude::SessionContext;
17use scouter_settings::ObjectStorageSettings;
18use scouter_types::{
19 psi::FeatureBinProportionResult, PsiServerRecord, ServerRecords, StorageType, ToDriftRecords,
20};
21use std::collections::BTreeMap;
22use std::sync::Arc;
23
24use super::types::BinnedTableName;
/// Reads and writes PSI drift records as parquet via an object store.
pub struct PsiDataFrame {
    /// Arrow schema for PSI record batches: created_at, space, name, version,
    /// feature, bin_id, bin_count (see `PsiDataFrame::new`).
    schema: Arc<Schema>,
    /// Backing object store; also supplies the DataFusion session
    /// (see `get_session` usage in the `ParquetFrame` impl).
    pub object_store: ObjectStore,
}
29
30#[async_trait]
31impl ParquetFrame for PsiDataFrame {
32 fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
33 PsiDataFrame::new(storage_settings)
34 }
35
36 async fn get_dataframe(&self, records: ServerRecords) -> Result<DataFrame, DataFrameError> {
37 let records = records.to_psi_drift_records()?;
38 let batch = self.build_batch(records)?;
39
40 let ctx = self.object_store.get_session()?;
41
42 let df = ctx.read_batches(vec![batch])?;
43 Ok(df)
44 }
45
46 fn storage_root(&self) -> String {
47 self.object_store.storage_settings.canonicalized_path()
48 }
49
50 fn storage_type(&self) -> StorageType {
51 self.object_store.storage_settings.storage_type.clone()
52 }
53
54 fn get_session_context(&self) -> Result<SessionContext, DataFrameError> {
55 Ok(self.object_store.get_session()?)
56 }
57
58 fn get_binned_sql(
59 &self,
60 bin: &f64,
61 start_time: &DateTime<Utc>,
62 end_time: &DateTime<Utc>,
63 space: &str,
64 name: &str,
65 version: &str,
66 ) -> String {
67 get_binned_psi_drift_records_query(bin, start_time, end_time, space, name, version)
68 }
69
70 fn table_name(&self) -> String {
71 BinnedTableName::Psi.to_string()
72 }
73}
74
75impl PsiDataFrame {
76 pub fn new(storage_settings: &ObjectStorageSettings) -> Result<Self, DataFrameError> {
77 let schema = Arc::new(Schema::new(vec![
78 Field::new(
79 "created_at",
80 DataType::Timestamp(TimeUnit::Nanosecond, None),
81 false,
82 ),
83 Field::new("space", DataType::Utf8, false),
84 Field::new("name", DataType::Utf8, false),
85 Field::new("version", DataType::Utf8, false),
86 Field::new("feature", DataType::Utf8, false),
87 Field::new("bin_id", DataType::UInt64, false),
88 Field::new("bin_count", DataType::UInt64, false),
89 ]));
90
91 let object_store = ObjectStore::new(storage_settings)?;
92
93 Ok(PsiDataFrame {
94 schema,
95 object_store,
96 })
97 }
98
99 fn build_batch(&self, records: Vec<PsiServerRecord>) -> Result<RecordBatch, DataFrameError> {
101 let created_at_array = TimestampNanosecondArray::from_iter_values(
102 records
103 .iter()
104 .map(|r| r.created_at.timestamp_nanos_opt().unwrap_or_default()),
105 );
106
107 let space_array = StringArray::from_iter_values(records.iter().map(|r| r.space.as_str()));
108 let name_array = StringArray::from_iter_values(records.iter().map(|r| r.name.as_str()));
109 let version_array =
110 StringArray::from_iter_values(records.iter().map(|r| r.version.as_str()));
111 let feature_array =
112 StringArray::from_iter_values(records.iter().map(|r| r.feature.as_str()));
113
114 let bin_id_array = UInt64Array::from_iter_values(records.iter().map(|r| r.bin_id as u64));
115 let bin_count_array =
116 UInt64Array::from_iter_values(records.iter().map(|r| r.bin_count as u64));
117
118 let batch = RecordBatch::try_new(
119 self.schema.clone(),
120 vec![
121 Arc::new(created_at_array),
122 Arc::new(space_array),
123 Arc::new(name_array),
124 Arc::new(version_array),
125 Arc::new(feature_array),
126 Arc::new(bin_id_array),
127 Arc::new(bin_count_array),
128 ],
129 )?;
130
131 Ok(batch)
132 }
133}
134
135fn extract_feature(batch: &RecordBatch) -> Result<String, DataFrameError> {
137 let feature_array = batch
138 .column(0)
139 .as_any()
140 .downcast_ref::<StringViewArray>()
141 .ok_or_else(|| DataFrameError::GetColumnError("feature"))?;
142 Ok(feature_array.value(0).to_string())
143}
144
145fn extract_created_at(batch: &RecordBatch) -> Result<Vec<DateTime<Utc>>, DataFrameError> {
147 let created_at_list = batch
148 .column(1)
149 .as_any()
150 .downcast_ref::<ListArray>()
151 .ok_or_else(|| DataFrameError::GetColumnError("created_at"))?;
152
153 let created_at_array = created_at_list.value(0);
154 Ok(created_at_array
155 .as_primitive::<arrow::datatypes::TimestampNanosecondType>()
156 .iter()
157 .filter_map(|ts| ts.map(|t| Utc.timestamp_nanos(t)))
158 .collect())
159}
160
161fn get_bin_proportions_struct(batch: &RecordBatch) -> Result<&ListArray, DataFrameError> {
163 batch
164 .column(2)
165 .as_any()
166 .downcast_ref::<ListArray>()
167 .ok_or_else(|| DataFrameError::GetColumnError("bin_proportions"))
168}
169
170fn get_bin_fields(structs: &StructArray) -> Result<(&ListArray, &ListArray), DataFrameError> {
172 let bin_ids = structs
173 .column_by_name("bin_id")
174 .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
175 .as_any()
176 .downcast_ref::<ListArray>()
177 .ok_or_else(|| DataFrameError::GetColumnError("bin_id"))?;
178
179 let proportions = structs
180 .column_by_name("proportion")
181 .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
182 .as_any()
183 .downcast_ref::<ListArray>()
184 .ok_or_else(|| DataFrameError::GetColumnError("proportion"))?;
185
186 Ok((bin_ids, proportions))
187}
188
189fn get_bin_ids(array: &dyn Array) -> Result<Vec<usize>, DataFrameError> {
191 Ok(array
192 .as_primitive::<UInt64Type>()
193 .iter()
194 .filter_map(|id| id.map(|i| i as usize))
195 .collect())
196}
197
198fn get_proportions(array: &dyn Array) -> Result<Vec<f64>, DataFrameError> {
201 Ok(array
202 .as_primitive::<Float32Type>()
203 .iter()
204 .filter_map(|p| p.map(|v| v as f64))
205 .collect())
206}
207
208fn create_bin_map(
210 bin_ids: &ListArray,
211 proportions: &ListArray,
212 index: usize,
213) -> Result<BTreeMap<usize, f64>, DataFrameError> {
214 let bin_ids = get_bin_ids(&bin_ids.value(index))?;
215 let proportions = get_proportions(&proportions.value(index))?;
216
217 Ok(bin_ids.into_iter().zip(proportions).collect())
218}
219
220fn extract_bin_proportions(
222 batch: &RecordBatch,
223) -> Result<Vec<BTreeMap<usize, f64>>, DataFrameError> {
224 let bin_structs = get_bin_proportions_struct(batch)?.value(0);
225 let bin_structs = bin_structs
226 .as_any()
227 .downcast_ref::<StructArray>()
228 .ok_or_else(|| DataFrameError::DowncastError("Bin structs"))?;
229
230 let (bin_ids_field, proportions_field) = get_bin_fields(bin_structs)?;
231
232 let mut bin_proportions = Vec::with_capacity(bin_structs.len());
233 for i in 0..bin_structs.len() {
234 let bin_map = create_bin_map(bin_ids_field, proportions_field, i)?;
235 bin_proportions.push(bin_map);
236 }
237
238 Ok(bin_proportions)
239}
240
241fn get_overall_proportions_struct(batch: &RecordBatch) -> Result<&StructArray, DataFrameError> {
243 let overall_proportions_struct = batch
244 .column(3)
245 .as_any()
246 .downcast_ref::<StructArray>()
247 .ok_or_else(|| DataFrameError::DowncastError("overall proportion struct"))?;
248
249 Ok(overall_proportions_struct)
250}
251
252fn get_overall_fields(
253 overall_struct: &StructArray,
254) -> Result<(&ListArray, &ListArray), DataFrameError> {
255 let overall_bin_ids = overall_struct
256 .column_by_name("bin_id")
257 .ok_or_else(|| DataFrameError::MissingFieldError("bin_id"))?
258 .as_any()
259 .downcast_ref::<ListArray>()
260 .ok_or_else(|| DataFrameError::DowncastError("bin_id"))?;
261
262 let overall_proportions = overall_struct
263 .column_by_name("proportion")
264 .ok_or_else(|| DataFrameError::MissingFieldError("proportion"))?
265 .as_any()
266 .downcast_ref::<ListArray>()
267 .ok_or_else(|| DataFrameError::DowncastError("proporition"))?;
268
269 Ok((overall_bin_ids, overall_proportions))
270}
271
272fn extract_overall_proportions(
273 batch: &RecordBatch,
274) -> Result<BTreeMap<usize, f64>, DataFrameError> {
275 let overall_struct = get_overall_proportions_struct(batch)?;
276 let (bin_ids, proportions) = get_overall_fields(overall_struct)?;
277
278 let bin_ids = get_bin_ids(&bin_ids.value(0))?;
279 let proportions = get_proportions(&proportions.value(0))?;
280
281 Ok(bin_ids.into_iter().zip(proportions).collect())
282}
283
284fn process_psi_record_batch(
293 batch: &RecordBatch,
294) -> Result<FeatureBinProportionResult, DataFrameError> {
295 Ok(FeatureBinProportionResult {
296 feature: extract_feature(batch)?,
297 created_at: extract_created_at(batch)?,
298 bin_proportions: extract_bin_proportions(batch)?,
299 overall_proportions: extract_overall_proportions(batch)?,
300 })
301}
302
303pub async fn dataframe_to_psi_drift_features(
311 df: DataFrame,
312) -> Result<Vec<FeatureBinProportionResult>, DataFrameError> {
313 let batches = df.collect().await?;
314
315 batches
316 .into_iter()
317 .map(|batch| process_psi_record_batch(&batch))
318 .collect()
319}