//! scouter_sql/sql/traits/psi.rs
//!
//! SQL logic for reading and writing PSI (Population Stability Index)
//! drift records, spanning both the live Postgres tables and archived
//! parquet data in object storage.
1use crate::sql::query::Queries;
2use crate::sql::utils::split_custom_interval;
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use scouter_dataframe::parquet::{dataframe_to_psi_drift_features, ParquetDataFrame};
6
7use crate::sql::error::SqlError;
8use itertools::multiunzip;
9use scouter_settings::ObjectStorageSettings;
10use scouter_types::psi::FeatureDistributions;
11use scouter_types::{
12    psi::{FeatureBinProportionResult, FeatureDistributionRow},
13    DriftRequest, PsiRecord, RecordType,
14};
15
16use sqlx::{postgres::PgQueryResult, Pool, Postgres};
17use std::collections::BTreeMap;
18use tracing::{debug, instrument};
19
20#[async_trait]
21pub trait PsiSqlLogic {
22    /// Inserts multiple PSI bin counts into the database in a batch.
23    ///
24    /// # Arguments
25    /// * `pool` - The database connection pool
26    /// * `records` - The PSI server records to insert
27    ///
28    /// # Returns
29    /// * A result containing the query result or an error
30    async fn insert_bin_counts_batch(
31        pool: &Pool<Postgres>,
32        records: &[PsiRecord],
33        entity_id: &i32,
34    ) -> Result<PgQueryResult, SqlError> {
35        if records.is_empty() {
36            return Err(SqlError::EmptyBatchError);
37        }
38
39        let query = Queries::InsertBinCountsBatch.get_query();
40
41        let (created_ats, entity_ids, features, bin_ids, bin_counts): (
42            Vec<DateTime<Utc>>,
43            Vec<i32>,
44            Vec<&str>,
45            Vec<i32>,
46            Vec<i32>,
47        ) = multiunzip(records.iter().map(|r| {
48            (
49                r.created_at,
50                entity_id,
51                r.feature.as_str(),
52                r.bin_id,
53                r.bin_count,
54            )
55        }));
56
57        sqlx::query(query)
58            .bind(created_ats)
59            .bind(entity_ids)
60            .bind(features)
61            .bind(bin_ids)
62            .bind(bin_counts)
63            .execute(pool)
64            .await
65            .map_err(SqlError::SqlxError)
66    }
67
68    /// Queries the database for PSI drift records based on a time window
69    /// and aggregation.
70    ///
71    /// # Arguments
72    /// * `pool` - The database connection pool
73    /// * `params` - The drift request parameters
74    ///
75    /// # Returns
76    /// * A vector of drift records
77    async fn get_records(
78        pool: &Pool<Postgres>,
79        params: &DriftRequest,
80        start_dt: DateTime<Utc>,
81        end_dt: DateTime<Utc>,
82        entity_id: &i32,
83    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
84        let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
85        let bin = minutes / params.max_data_points as f64;
86        let query = Queries::GetBinnedPsiFeatureBins.get_query();
87
88        let binned: Vec<FeatureBinProportionResult> = sqlx::query_as(query)
89            .bind(bin)
90            .bind(start_dt)
91            .bind(end_dt)
92            .bind(entity_id)
93            .fetch_all(pool)
94            .await
95            .map_err(SqlError::SqlxError)?;
96
97        Ok(binned)
98    }
99
100    /// DataFusion implementation for getting PSI drift records from archived data.
101    ///
102    /// # Arguments
103    /// * `params` - The drift request parameters
104    /// * `begin` - The start time of the time window
105    /// * `end` - The end time of the time window
106    /// * `minutes` - The number of minutes to bin the data
107    /// * `storage_settings` - The object storage settings
108    ///
109    /// # Returns
110    /// * A vector of drift records
111    #[instrument(skip_all)]
112    async fn get_archived_records(
113        params: &DriftRequest,
114        begin: DateTime<Utc>,
115        end: DateTime<Utc>,
116        minutes: i32,
117        storage_settings: &ObjectStorageSettings,
118        entity_id: &i32,
119    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
120        let path = format!("{}/psi", params.uid);
121        let bin = minutes as f64 / params.max_data_points as f64;
122
123        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::Psi)?
124            .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
125            .await?;
126
127        Ok(dataframe_to_psi_drift_features(archived_df).await?)
128    }
129
130    /// Helper for merging current and archived binned PSI drift records.
131    fn merge_feature_results(
132        results: Vec<FeatureBinProportionResult>,
133        feature_map: &mut BTreeMap<String, FeatureBinProportionResult>,
134    ) -> Result<(), SqlError> {
135        for result in results {
136            feature_map
137                .entry(result.feature.clone())
138                .and_modify(|existing| {
139                    existing.created_at.extend(result.created_at.iter());
140                    existing
141                        .bin_proportions
142                        .extend(result.bin_proportions.iter().cloned());
143
144                    for (k, v) in result.overall_proportions.iter() {
145                        existing
146                            .overall_proportions
147                            .entry(*k)
148                            .and_modify(|existing_value| {
149                                *existing_value = (*existing_value + *v) / 2.0;
150                            })
151                            .or_insert(*v);
152                    }
153                })
154                .or_insert(result);
155        }
156
157        Ok(())
158    }
159
160    // Queries the database for drift records based on a time window and aggregation.
161    // Based on the time window provided, a query or queries will be run against the short-term and
162    // archived data.
163    //
164    // # Arguments
165    //
166    // * `name` - The name of the service to query drift records for
167    // * `params` - The drift request parameters
168    // # Returns
169    //
170    // * A vector of drift records
171    #[instrument(skip_all)]
172    async fn get_binned_psi_drift_records(
173        pool: &Pool<Postgres>,
174        params: &DriftRequest,
175        retention_period: &i32,
176        storage_settings: &ObjectStorageSettings,
177        entity_id: &i32,
178    ) -> Result<Vec<FeatureBinProportionResult>, SqlError> {
179        if !params.has_custom_interval() {
180            debug!("No custom interval provided, using default");
181            let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
182            return Self::get_records(pool, params, start_dt, end_dt, entity_id).await;
183        }
184
185        debug!("Custom interval provided, using custom interval");
186        let interval = params.clone().to_custom_interval().unwrap();
187        let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
188        let mut feature_map = BTreeMap::new();
189
190        // Get current records if available
191        if let Some((active_begin, active_end)) = timestamps.active_range {
192            let current_results =
193                Self::get_records(pool, params, active_begin, active_end, entity_id).await?;
194            Self::merge_feature_results(current_results, &mut feature_map)?;
195        }
196
197        // Get archived records if available
198        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
199            if let Some(archived_minutes) = timestamps.archived_minutes {
200                let archived_results = Self::get_archived_records(
201                    params,
202                    archive_begin,
203                    archive_end,
204                    archived_minutes,
205                    storage_settings,
206                    entity_id,
207                )
208                .await?;
209
210                Self::merge_feature_results(archived_results, &mut feature_map)?;
211            }
212        }
213        Ok(feature_map.into_values().collect())
214    }
215
216    async fn get_feature_distributions(
217        pool: &Pool<Postgres>,
218        limit_datetime: &DateTime<Utc>,
219        features_to_monitor: &[String],
220        entity_id: &i32,
221    ) -> Result<FeatureDistributions, SqlError> {
222        let query = Queries::GetFeatureBinProportions.get_query();
223
224        let rows: Vec<FeatureDistributionRow> = sqlx::query_as(query)
225            .bind(entity_id)
226            .bind(limit_datetime)
227            .bind(features_to_monitor)
228            .fetch_all(pool)
229            .await
230            .map_err(SqlError::SqlxError)?;
231
232        Ok(FeatureDistributions::from_rows(rows))
233    }
234}