scouter_sql/sql/traits/psi.rs

use crate::sql::query::Queries;
use crate::sql::schema::FeatureBinProportionResultWrapper;
use crate::sql::schema::FeatureDistributionWrapper;
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use scouter_dataframe::parquet::{dataframe_to_psi_drift_features, ParquetDataFrame};

use crate::sql::error::SqlError;
use itertools::multiunzip;
use scouter_settings::ObjectStorageSettings;
use scouter_types::psi::FeatureDistributions;
use scouter_types::{
    psi::FeatureBinProportionResult, DriftRequest, PsiServerRecord, RecordType, ServiceInfo,
};
use sqlx::{postgres::PgQueryResult, Pool, Postgres};
use std::collections::BTreeMap;
use tracing::{debug, instrument};

#[async_trait]
pub trait PsiSqlLogic {
    /// Inserts multiple PSI bin counts into the database in a batch.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `records` - The PSI server records to insert
    ///
    /// # Returns
    /// * A result containing the query result or an error
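    ///
    /// # Example
    ///
    /// A minimal sketch (hypothetical `records` source; `S` is any implementor
    /// of `PsiSqlLogic`):
    ///
    /// ```ignore
    /// let records: Vec<PsiServerRecord> = collect_psi_records(); // hypothetical helper
    /// S::insert_bin_counts_batch(&pool, &records).await?;
    /// ```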
    async fn insert_bin_counts_batch(
        pool: &Pool<Postgres>,
        records: &[PsiServerRecord],
    ) -> Result<PgQueryResult, SqlError> {
        if records.is_empty() {
            return Err(SqlError::EmptyBatchError);
        }

        let query = Queries::InsertBinCountsBatch.get_query();

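        // Transpose the row-oriented records into one Vec per column so that
        // each column can be bound as a single array parameter (the batch
        // insert SQL presumably unnests these arrays server-side).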
        let (created_ats, names, spaces, versions, features, bin_ids, bin_counts): (
            Vec<DateTime<Utc>>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<i64>,
            Vec<i64>,
        ) = multiunzip(records.iter().map(|r| {
            (
                r.created_at,
                r.name.as_str(),
                r.space.as_str(),
                r.version.as_str(),
                r.feature.as_str(),
                r.bin_id as i64,
                r.bin_count as i64,
            )
        }));

        sqlx::query(&query.sql)
            .bind(created_ats)
            .bind(names)
            .bind(spaces)
            .bind(versions)
            .bind(features)
            .bind(bin_ids)
            .bind(bin_counts)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

    /// Queries the database for PSI drift records based on a time window
    /// and aggregation.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `minutes` - The size of the time window, in minutes
    ///
    /// # Returns
    /// * A vector of drift records
    async fn get_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        minutes: i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
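        // Bucket width in minutes: the requested window is spread evenly
        // across at most `max_data_points` buckets.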
        let bin = minutes as f64 / params.max_data_points as f64;
        let query = Queries::GetBinnedPsiFeatureBins.get_query();

        let binned: Vec<FeatureBinProportionResult> = sqlx::query_as(&query.sql)
            .bind(bin)
            .bind(minutes)
            .bind(&params.name)
            .bind(&params.space)
            .bind(&params.version)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?
            .into_iter()
            .map(|wrapper: FeatureBinProportionResultWrapper| wrapper.0)
            .collect();

        Ok(binned)
    }

    /// DataFusion implementation for getting PSI drift records from archived data.
    ///
    /// # Arguments
    /// * `params` - The drift request parameters
    /// * `begin` - The start time of the time window
    /// * `end` - The end time of the time window
    /// * `minutes` - The size of the time window, in minutes, used for binning
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * A vector of drift records
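    ///
    /// # Example
    ///
    /// A minimal sketch (hypothetical values; `S` is any implementor of `PsiSqlLogic`):
    ///
    /// ```ignore
    /// let results =
    ///     S::get_archived_records(&params, begin, end, 60, &storage_settings).await?;
    /// ```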
    #[instrument(skip_all)]
    async fn get_archived_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let path = format!("{}/{}/{}/psi", params.space, params.name, params.version);
        let bin = minutes as f64 / params.max_data_points as f64;

        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::Psi)?
            .get_binned_metrics(
                &path,
                &bin,
                &begin,
                &end,
                &params.space,
                &params.name,
                &params.version,
            )
            .await?;

        Ok(dataframe_to_psi_drift_features(archived_df).await?)
    }

    /// Helper for merging current and archived binned PSI drift records.
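    ///
    /// Timestamps and bin proportions are concatenated, while overlapping
    /// `overall_proportions` entries are combined with a simple unweighted
    /// average (e.g. an existing value of 0.4 merged with 0.6 yields 0.5).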
    fn merge_feature_results(
        results: Vec<FeatureBinProportionResult>,
        feature_map: &mut BTreeMap<String, FeatureBinProportionResult>,
    ) -> Result<(), SqlError> {
        for result in results {
            feature_map
                .entry(result.feature.clone())
                .and_modify(|existing| {
                    existing.created_at.extend(result.created_at.iter());
                    existing
                        .bin_proportions
                        .extend(result.bin_proportions.iter().cloned());

                    for (k, v) in result.overall_proportions.iter() {
                        existing
                            .overall_proportions
                            .entry(*k)
                            .and_modify(|existing_value| {
                                *existing_value = (*existing_value + *v) / 2.0;
                            })
                            .or_insert(*v);
                    }
                })
                .or_insert(result);
        }

        Ok(())
    }

    /// Queries the database for drift records based on a time window and aggregation.
    /// Depending on the time window provided, one or more queries will be run against
    /// the short-term and archived data.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `retention_period` - The retention period used to decide where the interval
    ///   splits between current and archived data
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * A vector of drift records
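    ///
    /// # Example
    ///
    /// A minimal sketch (hypothetical values; `S` is any implementor of `PsiSqlLogic`):
    ///
    /// ```ignore
    /// let retention_period: i32 = 30; // hypothetical retention period
    /// let records =
    ///     S::get_binned_psi_drift_records(&pool, &params, &retention_period, &settings)
    ///         .await?;
    /// ```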
    #[instrument(skip_all)]
    async fn get_binned_psi_drift_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let minutes = params.time_interval.to_minutes();
            return Self::get_records(pool, params, minutes).await;
        }

        debug!("Custom interval provided, using custom interval");
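        // Unwrap is safe here: `has_custom_interval()` was checked above.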
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.start, interval.end, retention_period)?;
        let mut feature_map = BTreeMap::new();

        // Get current records if available
        if let Some(minutes) = timestamps.current_minutes {
            let current_results = Self::get_records(pool, params, minutes).await?;
            Self::merge_feature_results(current_results, &mut feature_map)?;
        }

        // Get archived records if available
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                )
                .await?;

                Self::merge_feature_results(archived_results, &mut feature_map)?;
            }
        }
        Ok(feature_map.into_values().collect())
    }

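    /// Retrieves per-feature bin proportions for a service, keyed by feature
    /// name.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `service_info` - The name, space, and version identifying the service
    /// * `limit_datetime` - The timestamp bound applied by the query
    /// * `features_to_monitor` - The features to fetch distributions for
    ///
    /// # Returns
    /// * The feature distributions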
    async fn get_feature_distributions(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        limit_datetime: &DateTime<Utc>,
        features_to_monitor: &[String],
    ) -> Result<FeatureDistributions, SqlError> {
        let query = Queries::GetFeatureBinProportions.get_query();

        let feature_distributions: Vec<FeatureDistributionWrapper> = sqlx::query_as(&query.sql)
            .bind(&service_info.name)
            .bind(&service_info.space)
            .bind(&service_info.version)
            .bind(limit_datetime)
            .bind(features_to_monitor)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        let distributions = feature_distributions
            .into_iter()
            .map(|wrapper| (wrapper.0, wrapper.1))
            .collect();

        Ok(FeatureDistributions { distributions })
    }
}