scouter_sql/sql/
postgres.rs

1use crate::sql::traits::{
2    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3    PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
// TODO: Explore refactoring this into multiple client types (e.g., spc, psi, etc.).
// The Postgres client is one of the lowest-level abstractions, so splitting it may not be worth it
// if it makes the server logic more cumbersome. Worth exploring, though.
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32    /// Setup the application with the given database pool.
33    ///
34    /// # Returns
35    ///
36    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
37    #[instrument(skip(database_settings))]
38    pub async fn create_db_pool(
39        database_settings: &DatabaseSettings,
40    ) -> Result<Pool<Postgres>, SqlError> {
41        let pool = match PgPoolOptions::new()
42            .max_connections(database_settings.max_connections)
43            .connect(&database_settings.connection_uri)
44            .await
45        {
46            Ok(pool) => {
47                info!("✅ Successfully connected to database");
48                pool
49            }
50            Err(err) => {
51                error!("🚨 Failed to connect to database {:?}", err);
52                std::process::exit(1);
53            }
54        };
55
56        // Run migrations
57        if let Err(err) = Self::run_migrations(&pool).await {
58            error!("🚨 Failed to run migrations {:?}", err);
59            std::process::exit(1);
60        }
61
62        Ok(pool)
63    }
64
65    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66        info!("Running migrations");
67        sqlx::migrate!("src/migrations")
68            .run(pool)
69            .await
70            .map_err(SqlError::MigrateError)?;
71
72        debug!("Migrations complete");
73
74        Ok(())
75    }
76}
77
78pub struct MessageHandler {}
79
80impl MessageHandler {
81    const DEFAULT_BATCH_SIZE: usize = 500;
82    #[instrument(skip_all)]
83    pub async fn insert_server_records(
84        pool: &Pool<Postgres>,
85        records: &ServerRecords,
86    ) -> Result<(), SqlError> {
87        debug!("Inserting server records: {:?}", records);
88
89        match records.record_type()? {
90            RecordType::Spc => {
91                let spc_records = records.to_spc_drift_records()?;
92                debug!("SPC record count: {}", spc_records.len());
93
94                for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
95                    PostgresClient::insert_spc_drift_records_batch(pool, chunk)
96                        .await
97                        .map_err(|e| {
98                            error!("Failed to insert SPC drift records batch: {:?}", e);
99                            e
100                        })?;
101                }
102            }
103
104            RecordType::Psi => {
105                let psi_records = records.to_psi_drift_records()?;
106                debug!("PSI record count: {}", psi_records.len());
107
108                for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
109                    PostgresClient::insert_bin_counts_batch(pool, chunk)
110                        .await
111                        .map_err(|e| {
112                            error!("Failed to insert PSI drift records batch: {:?}", e);
113                            e
114                        })?;
115                }
116            }
117            RecordType::Custom => {
118                let custom_records = records.to_custom_metric_drift_records()?;
119                debug!("Custom record count: {}", custom_records.len());
120
121                for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
122                    PostgresClient::insert_custom_metric_values_batch(pool, chunk)
123                        .await
124                        .map_err(|e| {
125                            error!("Failed to insert custom metric records batch: {:?}", e);
126                            e
127                        })?;
128                }
129            }
130
131            _ => {
132                error!(
133                    "Unsupported record type for batch insert: {:?}",
134                    records.record_type()?
135                );
136                return Err(SqlError::UnsupportedBatchTypeError);
137            }
138        }
139
140        Ok(())
141    }
142}
143
/// Runs database integration tests.
///
/// These tests connect via `PostgresClient::create_db_pool(&DatabaseSettings::default())`, so
/// they need a reachable Postgres instance matching the default settings.
///
/// NOTE(review): every test calls `db_pool()`, which deletes all rows from the shared scouter
/// tables. Running the tests concurrently (the `cargo test` default) can let one test wipe
/// another's data. Presumably they are run with `--test-threads=1` — confirm.
///
/// Note - binned queries targeting custom intervals with long-term and short-term data are
/// done in the scouter-server integration tests
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    // Identifiers shared by every record, profile and request created in these tests.
    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Deletes every row from the scouter tables written by these tests.
    ///
    /// Panics if the SQL fails.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

             DELETE
            FROM scouter.user;
            "#,
        )
        .fetch_all(pool)
        .await
        .unwrap();
    }

    /// Builds a pool from `DatabaseSettings::default()`, which also runs the migrations, then
    /// empties the test tables so each test starts from a clean state.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    /// Smoke test: connecting, migrating and cleaning up succeeds.
    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    /// Inserts ten SPC drift alerts, then reads them back with no limit, a row limit of 1,
    /// and a `limit_datetime` limit.
    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        // Captured before any alert is inserted; used below as `limit_datetime`.
        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            // Alert payload: {"0": "0", "1": "1", ..., "9": "9"}.
            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // get alerts
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // get alerts limit 1
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // get alerts limit timestamp
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    /// A two-record SPC batch insert reports two affected rows.
    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record1 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let record2 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            value: 2.0,
        };

        let result = PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    /// A two-record PSI bin-count batch insert reports two affected rows.
    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record1 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let record2 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            bin_id: 2,
            bin_count: 2,
        };

        let result = PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    /// Inserting a default `ObservabilityMetrics` record affects exactly one row.
    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    /// Exercises the drift-profile lifecycle: insert, update, read back, then deactivate.
    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        // Change a field so the read-back below can only match if the update was stored.
        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        // NOTE(review): only checks that the status update succeeds; the stored status is not
        // read back.
        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    /// Inserts 10 batches of SPC records over features `test0`..`test9`. It then checks the
    /// feature listing, the raw drift records and the binned drift records.
    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        // Captured before the inserts; passed to `get_spc_drift_records` below.
        let timestamp = Utc::now();

        for _ in 0..10 {
            let mut records = Vec::new();
            for j in 0..10 {
                // Offset each record by j microseconds so timestamps within a batch differ.
                let record = SpcServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{j}"),
                    value: j as f64,
                };

                records.push(record);
            }

            let result = PostgresClient::insert_spc_drift_records_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), records.len() as u64);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    /// Stores a PSI profile, inserts random bin counts, and checks the per-bin proportions
    /// and the binned PSI drift records.
    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // Inclusive ranges: the profile defines 4 features (feature0..feature3), each with
        // 6 bins (ids 0..=5).
        // NOTE(review): bin counts below are only inserted for `0..num_features`
        // (feature0..feature2), so feature3 has a profile but no data. Confirm this is
        // intentional.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bind_id| Bin {
                        id: bind_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{feature}");
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let _ = PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        // For feature0..feature2 and every bin 0..=5, insert 101 records with a random
        // count in 0..10.
        for feature in 0..num_features {
            for bin in 0..=num_bins {
                let mut records = Vec::new();
                for j in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{feature}"),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    records.push(record);
                }
                PostgresClient::insert_bin_counts_batch(&pool, &records)
                    .await
                    .unwrap();
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        // Check that bin 1 of feature0 has a proportion between 0.1 and 0.2. All 6 bins got
        // identically distributed random counts, so each bin's share should be near 1/6.
        // This assumes the proportion is the bin's share of the feature's total count.
        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        // Presumably one entry per feature with recorded bin counts (feature0..feature2).
        assert_eq!(binned_records.len(), 3);
    }

    /// Inserts and reads back custom metric values. Covers multi-record metrics and a metric
    /// with a single value.
    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        // 25 random values each for metric0 and metric1.
        for i in 0..2 {
            let mut records = Vec::new();
            for j in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };
                records.push(record);
            }
            let result = PostgresClient::insert_custom_metric_values_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), 25);
        }

        // Insert a single record for metric3 to check that the statistics functions handle a
        // metric with only one value.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_values_batch(&pool, &[record])
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        // Three distinct metrics were inserted: metric0, metric1 and metric3.
        assert_eq!(binned_records.metrics.len(), 3);
    }

    /// Exercises the user lifecycle: create, read, update, list, last-admin check, delete.
    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // update user
        user.active = false;
        user.refresh_token = Some("token".to_string());

        // Update
        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // get users
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // get last admin
        // Expected false for this user, presumably because it was created without admin
        // permissions — confirm against `is_last_admin`.
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // delete
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}