use crate::sql::query::Queries;
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use scouter_dataframe::parquet::{dataframe_to_psi_drift_features, ParquetDataFrame};

use crate::sql::error::SqlError;
use itertools::multiunzip;
use scouter_settings::ObjectStorageSettings;
use scouter_types::psi::FeatureDistributions;
use scouter_types::{
    psi::{FeatureBinProportionResult, FeatureDistributionRow},
    DriftRequest, PsiRecord, RecordType,
};

use sqlx::{postgres::PgQueryResult, Pool, Postgres};
use std::collections::BTreeMap;
use tracing::{debug, instrument};

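/// SQL access layer for PSI (Population Stability Index) drift records:
/// batch inserts of bin counts plus binned reads that span both live
/// Postgres data and archived Parquet in object storage.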
#[async_trait]
pub trait PsiSqlLogic {
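    /// Insert a batch of PSI bin-count records in a single statement by
    /// binding one parallel column vector per field.
    ///
    /// Returns `SqlError::EmptyBatchError` if `records` is empty.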
    async fn insert_bin_counts_batch(
        pool: &Pool<Postgres>,
        records: &[PsiRecord],
        entity_id: &i32,
    ) -> Result<PgQueryResult, SqlError> {
        if records.is_empty() {
            return Err(SqlError::EmptyBatchError);
        }

        let query = Queries::InsertBinCountsBatch.get_query();

        let (created_ats, entity_ids, features, bin_ids, bin_counts): (
            Vec<DateTime<Utc>>,
            Vec<i32>,
            Vec<&str>,
            Vec<i32>,
            Vec<i32>,
        ) = multiunzip(records.iter().map(|r| {
            (
                r.created_at,
                // Dereference so the tuple element matches the declared `Vec<i32>` column
                *entity_id,
                r.feature.as_str(),
                r.bin_id,
                r.bin_count,
            )
        }));

        sqlx::query(query)
            .bind(created_ats)
            .bind(entity_ids)
            .bind(features)
            .bind(bin_ids)
            .bind(bin_counts)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

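    /// Fetch binned feature bin proportions from Postgres, with the bucket
    /// width chosen so the window yields at most `params.max_data_points`
    /// data points.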
    async fn get_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        start_dt: DateTime<Utc>,
        end_dt: DateTime<Utc>,
        entity_id: &i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
        let bin = minutes / params.max_data_points as f64;
        let query = Queries::GetBinnedPsiFeatureBins.get_query();

        let binned: Vec<FeatureBinProportionResult> = sqlx::query_as(query)
            .bind(bin)
            .bind(start_dt)
            .bind(end_dt)
            .bind(entity_id)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(binned)
    }

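    /// Fetch binned PSI records for a window that falls outside the Postgres
    /// retention period, reading archived Parquet from object storage and
    /// converting the resulting dataframe into drift feature results.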
    #[instrument(skip_all)]
    async fn get_archived_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let path = format!("{}/psi", params.uid);
        let bin = minutes as f64 / params.max_data_points as f64;

        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::Psi)?
            .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
            .await?;

        Ok(dataframe_to_psi_drift_features(archived_df).await?)
    }

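    /// Merge feature results into `feature_map`, concatenating timestamps and
    /// bin proportions and averaging the overall proportions per bin when a
    /// feature appears in both the active and archived result sets.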
    fn merge_feature_results(
        results: Vec<FeatureBinProportionResult>,
        feature_map: &mut BTreeMap<String, FeatureBinProportionResult>,
    ) -> Result<(), SqlError> {
        for result in results {
            feature_map
                .entry(result.feature.clone())
                .and_modify(|existing| {
                    existing.created_at.extend(result.created_at.iter());
                    existing
                        .bin_proportions
                        .extend(result.bin_proportions.iter().cloned());

                    for (k, v) in result.overall_proportions.iter() {
                        existing
                            .overall_proportions
                            .entry(*k)
                            .and_modify(|existing_value| {
                                *existing_value = (*existing_value + *v) / 2.0;
                            })
                            .or_insert(*v);
                    }
                })
                .or_insert(result);
        }

        Ok(())
    }

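    /// Top-level read path for binned PSI drift records. Default intervals
    /// read from Postgres directly; custom intervals are split against the
    /// retention period so recent data comes from Postgres and older data
    /// from the archive, then the two result sets are merged per feature.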
    #[instrument(skip_all)]
    async fn get_binned_psi_drift_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
            return Self::get_records(pool, params, start_dt, end_dt, entity_id).await;
        }

        debug!("Custom interval provided, using custom interval");
        // Guarded by the `has_custom_interval` check above, so this cannot panic
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
        let mut feature_map = BTreeMap::new();

        if let Some((active_begin, active_end)) = timestamps.active_range {
            let current_results =
                Self::get_records(pool, params, active_begin, active_end, entity_id).await?;
            Self::merge_feature_results(current_results, &mut feature_map)?;
        }

        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                    entity_id,
                )
                .await?;

                Self::merge_feature_results(archived_results, &mut feature_map)?;
            }
        }
        Ok(feature_map.into_values().collect())
    }

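    /// Load feature bin proportions for the monitored features, bounded by
    /// `limit_datetime`, and assemble them into `FeatureDistributions`.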
    async fn get_feature_distributions(
        pool: &Pool<Postgres>,
        limit_datetime: &DateTime<Utc>,
        features_to_monitor: &[String],
        entity_id: &i32,
    ) -> Result<FeatureDistributions, SqlError> {
        let query = Queries::GetFeatureBinProportions.get_query();

        let rows: Vec<FeatureDistributionRow> = sqlx::query_as(query)
            .bind(entity_id)
            .bind(limit_datetime)
            .bind(features_to_monitor)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(FeatureDistributions::from_rows(rows))
    }
}