// scouter_sql/sql/postgres.rs

1use crate::sql::traits::{
2    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3    PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
// TODO: Explore refactoring this into multiple client types (e.g., SPC, PSI).
// The Postgres client is one of the lowest-level abstractions, so splitting it may not be worth it
// if it makes the server logic more cumbersome. Worth exploring, though.
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32    /// Setup the application with the given database pool.
33    ///
34    /// # Returns
35    ///
36    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
37    #[instrument(skip(database_settings))]
38    pub async fn create_db_pool(
39        database_settings: &DatabaseSettings,
40    ) -> Result<Pool<Postgres>, SqlError> {
41        let pool = match PgPoolOptions::new()
42            .max_connections(database_settings.max_connections)
43            .connect(&database_settings.connection_uri)
44            .await
45        {
46            Ok(pool) => {
47                info!("✅ Successfully connected to database");
48                pool
49            }
50            Err(err) => {
51                error!("🚨 Failed to connect to database {:?}", err);
52                std::process::exit(1);
53            }
54        };
55
56        // Run migrations
57        if let Err(err) = Self::run_migrations(&pool).await {
58            error!("🚨 Failed to run migrations {:?}", err);
59            std::process::exit(1);
60        }
61
62        Ok(pool)
63    }
64
65    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66        info!("Running migrations");
67        sqlx::migrate!("src/migrations")
68            .run(pool)
69            .await
70            .map_err(SqlError::MigrateError)?;
71
72        debug!("Migrations complete");
73
74        Ok(())
75    }
76}
77
78pub struct MessageHandler {}
79
80impl MessageHandler {
81    #[instrument(skip_all)]
82    pub async fn insert_server_records(
83        pool: &Pool<Postgres>,
84        records: &ServerRecords,
85    ) -> Result<(), SqlError> {
86        debug!("Inserting server records: {:?}", records);
87        match records.record_type()? {
88            RecordType::Spc => {
89                debug!("SPC record count: {:?}", records.len());
90                let records = records.to_spc_drift_records()?;
91                for record in records.iter() {
92                    let _ = PostgresClient::insert_spc_drift_record(pool, record)
93                        .await
94                        .map_err(|e| {
95                            error!("Failed to insert drift record: {:?}", e);
96                        });
97                }
98            }
99            RecordType::Observability => {
100                debug!("Observability record count: {:?}", records.len());
101                let records = records.to_observability_drift_records()?;
102                for record in records.iter() {
103                    let _ = PostgresClient::insert_observability_record(pool, record)
104                        .await
105                        .map_err(|e| {
106                            error!("Failed to insert observability record: {:?}", e);
107                        });
108                }
109            }
110            RecordType::Psi => {
111                debug!("PSI record count: {:?}", records.len());
112                let records = records.to_psi_drift_records()?;
113                for record in records.iter() {
114                    let _ = PostgresClient::insert_bin_counts(pool, record)
115                        .await
116                        .map_err(|e| {
117                            error!("Failed to insert bin count record: {:?}", e);
118                        });
119                }
120            }
121            RecordType::Custom => {
122                debug!("Custom record count: {:?}", records.len());
123                let records = records.to_custom_metric_drift_records()?;
124                for record in records.iter() {
125                    let _ = PostgresClient::insert_custom_metric_value(pool, record)
126                        .await
127                        .map_err(|e| {
128                            error!("Failed to insert bin count record: {:?}", e);
129                        });
130                }
131            }
132        };
133        Ok(())
134    }
135}
136
/// Runs database integration tests.
///
/// Note: binned queries targeting custom intervals with long-term and short-term data are
/// covered in the scouter-server integration tests.
///
/// NOTE(review): every test builds its pool from `DatabaseSettings::default()` and calls
/// `cleanup`, which deletes all rows in the scouter tables. Tests that share one database and
/// run concurrently could therefore remove each other's rows. Presumably these tests run
/// against a dedicated database, possibly single-threaded — confirm.
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Delete all rows from the scouter tables exercised by these tests.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;
            "#,
        )
        // DELETE statements produce no rows of interest, so execute instead of collecting rows.
        .execute(pool)
        .await
        .unwrap();
    }

    /// Connect with default settings (which also runs migrations), then empty the tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // get alerts
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // get alerts limit 1
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // get alerts limit timestamp
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let result = PostgresClient::insert_spc_drift_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let result = PostgresClient::insert_bin_counts(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{}", j),
                    value: j as f64,
                };

                let result = PostgresClient::insert_spc_drift_record(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // The profile declares `num_features + 1` features, each with `num_bins + 1` bins.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{}", feature);
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let _ = PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        // Only the first `num_features` features receive bin counts; the last profile feature
        // gets no data.
        for feature in 0..num_features {
            for bin in 0..=num_bins {
                for _ in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now(),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{}", feature),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    PostgresClient::insert_bin_counts(&pool, &record)
                        .await
                        .unwrap();
                }
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        // Every bin receives the same number of uniformly random counts. Bin 1's share should
        // therefore be roughly 1 / (num_bins + 1) ≈ 0.17, checked with a loose (0.1, 0.2) window.
        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            for _ in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{}", i),
                    value: rand::rng().random_range(0..10) as f64,
                };

                let result = PostgresClient::insert_custom_metric_value(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        // Insert a single record for a third metric to check that the statistics functions
        // handle a metric that has only one value.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_value(&pool, &record)
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // update user
        user.active = false;
        user.refresh_token = Some("token".to_string());

        // Update
        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // get users
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // get last admin
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // delete
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}