scouter_sql/sql/traits/psi.rs

use crate::sql::error::SqlError;
use crate::sql::query::Queries;
use crate::sql::schema::{FeatureBinProportionResultWrapper, FeatureDistributionWrapper};
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use scouter_dataframe::parquet::{dataframe_to_psi_drift_features, ParquetDataFrame};
use scouter_settings::ObjectStorageSettings;
use scouter_types::psi::{FeatureBinProportionResult, FeatureDistributions};
use scouter_types::{DriftRequest, PsiServerRecord, RecordType, ServiceInfo};
use sqlx::{postgres::PgQueryResult, Pool, Postgres};
use std::collections::BTreeMap;
use tracing::{debug, instrument};

#[async_trait]
pub trait PsiSqlLogic {
    /// Inserts a PSI bin count into the database.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `record` - The PSI server record to insert
    ///
    /// # Returns
    /// * A result containing the query result or an error
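    ///
    /// A minimal usage sketch (hypothetical; assumes an already-configured
    /// `Pool<Postgres>` and a populated `PsiServerRecord`):
    ///
    /// ```rust,ignore
    /// // All methods on this trait have default bodies, so an empty impl suffices.
    /// struct PsiSql;
    ///
    /// #[async_trait]
    /// impl PsiSqlLogic for PsiSql {}
    ///
    /// PsiSql::insert_bin_counts(&pool, &record).await?;
    /// ```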
    async fn insert_bin_counts(
        pool: &Pool<Postgres>,
        record: &PsiServerRecord,
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertBinCounts.get_query();

        sqlx::query(&query.sql)
            .bind(record.created_at)
            .bind(&record.name)
            .bind(&record.space)
            .bind(&record.version)
            .bind(&record.feature)
            .bind(record.bin_id as i64)
            .bind(record.bin_count as i64)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

    /// Queries the database for PSI drift records based on a time window
    /// and aggregation.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `minutes` - The size of the time window, in minutes
    ///
    /// # Returns
    /// * A vector of drift records
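    ///
    /// The bin width passed to the query is `minutes / max_data_points`, so results
    /// are capped at roughly `max_data_points` time bins. A sketch of a call
    /// (hypothetical; reuses the empty `PsiSql` impl shown above and assumes `pool`
    /// and `params` exist):
    ///
    /// ```rust,ignore
    /// // 1440 minutes = a 24-hour window.
    /// let results = PsiSql::get_records(&pool, &params, 1440).await?;
    /// for feature in &results {
    ///     println!("{}: {} bins", feature.feature, feature.bin_proportions.len());
    /// }
    /// ```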
    async fn get_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        minutes: i32,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let bin = minutes as f64 / params.max_data_points as f64;
        let query = Queries::GetBinnedPsiFeatureBins.get_query();

        let binned: Vec<FeatureBinProportionResult> = sqlx::query_as(&query.sql)
            .bind(bin)
            .bind(minutes)
            .bind(&params.name)
            .bind(&params.space)
            .bind(&params.version)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?
            .into_iter()
            .map(|wrapper: FeatureBinProportionResultWrapper| wrapper.0)
            .collect();

        Ok(binned)
    }

    /// DataFusion implementation for getting PSI drift records from archived data.
    ///
    /// # Arguments
    /// * `params` - The drift request parameters
    /// * `begin` - The start time of the time window
    /// * `end` - The end time of the time window
    /// * `minutes` - The number of minutes to bin the data
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * A vector of drift records
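    ///
    /// Archived records are read from object storage under the path
    /// `{space}/{name}/{version}/psi`, and the same `minutes / max_data_points`
    /// binning used for live records is applied to the archived data.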
    #[instrument(skip_all)]
    async fn get_archived_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        let path = format!("{}/{}/{}/psi", params.space, params.name, params.version);
        let bin = minutes as f64 / params.max_data_points as f64;

        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::Psi)?
            .get_binned_metrics(
                &path,
                &bin,
                &begin,
                &end,
                &params.space,
                &params.name,
                &params.version,
            )
            .await?;

        Ok(dataframe_to_psi_drift_features(archived_df).await?)
    }

    /// Helper for merging current and archived binned PSI drift records.
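    ///
    /// For a feature present in both sets, `created_at` timestamps and
    /// `bin_proportions` are concatenated, while overlapping entries in
    /// `overall_proportions` are averaged: for example, an existing value of
    /// 0.4 merged with an incoming 0.6 becomes (0.4 + 0.6) / 2.0 = 0.5.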
    fn merge_feature_results(
        results: Vec<FeatureBinProportionResult>,
        feature_map: &mut BTreeMap<String, FeatureBinProportionResult>,
    ) -> Result<(), SqlError> {
        for result in results {
            feature_map
                .entry(result.feature.clone())
                .and_modify(|existing| {
                    // Concatenate the per-bin time series from both sources.
                    existing.created_at.extend(result.created_at.iter());
                    existing
                        .bin_proportions
                        .extend(result.bin_proportions.iter().cloned());

                    // Average overall proportions for bins present in both sources.
                    for (k, v) in result.overall_proportions.iter() {
                        existing
                            .overall_proportions
                            .entry(*k)
                            .and_modify(|existing_value| {
                                *existing_value = (*existing_value + *v) / 2.0;
                            })
                            .or_insert(*v);
                    }
                })
                .or_insert(result);
        }

        Ok(())
    }

    /// Queries the database for drift records based on a time window and aggregation.
    /// Depending on the time window provided, one or more queries will be run against
    /// the short-term (Postgres) and archived (object storage) data.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `retention_period` - The short-term retention period, used to split the
    ///   requested window into current and archived ranges
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * A vector of drift records
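    ///
    /// A sketch of the main read path (hypothetical; `pool`, `params`, and
    /// `storage_settings` are assumed to be configured by the caller, and the
    /// retention period value is arbitrary):
    ///
    /// ```rust,ignore
    /// let retention_period = 30;
    /// let records = PsiSql::get_binned_psi_drift_records(
    ///     &pool,
    ///     &params,
    ///     &retention_period,
    ///     &storage_settings,
    /// )
    /// .await?;
    /// ```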
    #[instrument(skip_all)]
    async fn get_binned_psi_drift_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let minutes = params.time_interval.to_minutes();
            return Self::get_records(pool, params, minutes).await;
        }

        debug!("Custom interval provided, using custom interval");
        // Safe to unwrap: has_custom_interval() was checked above.
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.start, interval.end, retention_period)?;
        let mut feature_map = BTreeMap::new();

        // Get current records if available
        if let Some(minutes) = timestamps.current_minutes {
            let current_results = Self::get_records(pool, params, minutes).await?;
            Self::merge_feature_results(current_results, &mut feature_map)?;
        }

        // Get archived records if available
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                )
                .await?;

                Self::merge_feature_results(archived_results, &mut feature_map)?;
            }
        }

        Ok(feature_map.into_values().collect())
    }
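    /// Queries the database for the binned feature distributions used to compute
    /// PSI for the given service.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `service_info` - The name, space, and version identifying the service
    /// * `limit_datetime` - The timestamp used to bound which records are considered
    /// * `features_to_monitor` - The features to fetch distributions for
    ///
    /// # Returns
    /// * The feature distributions for the requested features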
    async fn get_feature_distributions(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        limit_datetime: &DateTime<Utc>,
        features_to_monitor: &[String],
    ) -> Result<FeatureDistributions, SqlError> {
        let query = Queries::GetFeatureBinProportions.get_query();

        let feature_distributions: Vec<FeatureDistributionWrapper> = sqlx::query_as(&query.sql)
            .bind(&service_info.name)
            .bind(&service_info.space)
            .bind(&service_info.version)
            .bind(limit_datetime)
            .bind(features_to_monitor)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        let distributions = feature_distributions
            .into_iter()
            .map(|wrapper| (wrapper.0, wrapper.1))
            .collect();

        Ok(FeatureDistributions { distributions })
    }
}