use crate::sql::query::Queries;
use crate::sql::schema::FeatureBinProportionResultWrapper;
use crate::sql::schema::FeatureDistributionWrapper;
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use scouter_dataframe::parquet::{dataframe_to_psi_drift_features, ParquetDataFrame};

use crate::sql::error::SqlError;
use itertools::multiunzip;
use scouter_settings::ObjectStorageSettings;
use scouter_types::psi::FeatureDistributions;
use scouter_types::{
    psi::FeatureBinProportionResult, DriftRequest, PsiServerRecord, RecordType, ServiceInfo,
};
use sqlx::{postgres::PgQueryResult, Pool, Postgres};
use std::collections::BTreeMap;
use tracing::{debug, instrument};

#[async_trait]
pub trait PsiSqlLogic {
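    /// Inserts a batch of PSI bin count records in a single bulk statement.
    ///
    /// Returns an `EmptyBatchError` if `records` is empty; otherwise the record
    /// fields are unzipped into per-column vectors and bound to the batch insert query.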
    async fn insert_bin_counts_batch(
        pool: &Pool<Postgres>,
        records: &[PsiServerRecord],
    ) -> Result<PgQueryResult, SqlError> {
        if records.is_empty() {
            return Err(SqlError::EmptyBatchError);
        }

        let query = Queries::InsertBinCountsBatch.get_query();

        let (created_ats, names, spaces, versions, features, bin_ids, bin_counts): (
            Vec<DateTime<Utc>>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<i64>,
            Vec<i64>,
        ) = multiunzip(records.iter().map(|r| {
            (
                r.created_at,
                r.name.as_str(),
                r.space.as_str(),
                r.version.as_str(),
                r.feature.as_str(),
                r.bin_id as i64,
                r.bin_count as i64,
            )
        }));

        sqlx::query(&query.sql)
            .bind(created_ats)
            .bind(names)
            .bind(spaces)
            .bind(versions)
            .bind(features)
            .bind(bin_ids)
            .bind(bin_counts)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

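    /// Fetches binned PSI feature bin proportions from Postgres for the given
    /// drift request, bucketing the last `minutes` of data into at most
    /// `params.max_data_points` bins.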
    async fn get_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        minutes: i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let bin = minutes as f64 / params.max_data_points as f64;
        let query = Queries::GetBinnedPsiFeatureBins.get_query();

        let binned: Vec<FeatureBinProportionResult> = sqlx::query_as(&query.sql)
            .bind(bin)
            .bind(minutes)
            .bind(&params.name)
            .bind(&params.space)
            .bind(&params.version)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?
            .into_iter()
            .map(|wrapper: FeatureBinProportionResultWrapper| wrapper.0)
            .collect();

        Ok(binned)
    }

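    /// Fetches binned PSI records from archived parquet data in object storage
    /// for the given time range, then converts the resulting dataframe back into
    /// feature bin proportion results.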
    #[instrument(skip_all)]
    async fn get_archived_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let path = format!("{}/{}/{}/psi", params.space, params.name, params.version);
        let bin = minutes as f64 / params.max_data_points as f64;

        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::Psi)?
            .get_binned_metrics(
                &path,
                &bin,
                &begin,
                &end,
                &params.space,
                &params.name,
                &params.version,
            )
            .await?;

        Ok(dataframe_to_psi_drift_features(archived_df).await?)
    }

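    /// Merges a set of feature results into `feature_map`, keyed by feature name.
    /// Timestamps and bin proportions are appended, while overall proportions for
    /// matching bins are averaged with the existing values.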
    fn merge_feature_results(
        results: Vec<FeatureBinProportionResult>,
        feature_map: &mut BTreeMap<String, FeatureBinProportionResult>,
    ) -> Result<(), SqlError> {
        for result in results {
            feature_map
                .entry(result.feature.clone())
                .and_modify(|existing| {
                    existing.created_at.extend(result.created_at.iter());
                    existing
                        .bin_proportions
                        .extend(result.bin_proportions.iter().cloned());

                    for (k, v) in result.overall_proportions.iter() {
                        existing
                            .overall_proportions
                            .entry(*k)
                            .and_modify(|existing_value| {
                                *existing_value = (*existing_value + *v) / 2.0;
                            })
                            .or_insert(*v);
                    }
                })
                .or_insert(result);
        }

        Ok(())
    }

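    /// Retrieves binned PSI drift records for a drift request.
    ///
    /// Without a custom interval, only current (non-archived) data is queried.
    /// With a custom interval, the interval is split against the retention period
    /// and results from Postgres and archived object storage are merged per feature.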
    #[instrument(skip_all)]
    async fn get_binned_psi_drift_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let minutes = params.time_interval.to_minutes();
            return Self::get_records(pool, params, minutes).await;
        }

        debug!("Custom interval provided, splitting into current and archived ranges");
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.start, interval.end, retention_period)?;
        let mut feature_map = BTreeMap::new();

        // Records still within the retention period live in Postgres.
        if let Some(minutes) = timestamps.current_minutes {
            let current_results = Self::get_records(pool, params, minutes).await?;
            Self::merge_feature_results(current_results, &mut feature_map)?;
        }

        // Records older than the retention period have been archived to object storage.
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                )
                .await?;

                Self::merge_feature_results(archived_results, &mut feature_map)?;
            }
        }

        Ok(feature_map.into_values().collect())
    }

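    /// Retrieves the stored feature bin distributions for a service, filtered by
    /// `limit_datetime` and the set of monitored features.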
    async fn get_feature_distributions(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        limit_datetime: &DateTime<Utc>,
        features_to_monitor: &[String],
    ) -> Result<FeatureDistributions, SqlError> {
        let query = Queries::GetFeatureBinProportions.get_query();

        let feature_distributions: Vec<FeatureDistributionWrapper> = sqlx::query_as(&query.sql)
            .bind(&service_info.name)
            .bind(&service_info.space)
            .bind(&service_info.version)
            .bind(limit_datetime)
            .bind(features_to_monitor)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        let distributions = feature_distributions
            .into_iter()
            .map(|wrapper| (wrapper.0, wrapper.1))
            .collect();

        Ok(FeatureDistributions { distributions })
    }
}