scouter_sql/sql/
postgres.rs

1use crate::sql::traits::{
2    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3    PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
// TODO: Explore refactoring and breaking this out into multiple client types (e.g., spc, psi, etc.).
// The Postgres client is one of the lowest-level abstractions, so splitting it may not be worth it,
// as it could complicate the server logic. Still worth exploring.
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32    /// Setup the application with the given database pool.
33    ///
34    /// # Returns
35    ///
36    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
37    #[instrument(skip(database_settings))]
38    pub async fn create_db_pool(
39        database_settings: &DatabaseSettings,
40    ) -> Result<Pool<Postgres>, SqlError> {
41        let pool = match PgPoolOptions::new()
42            .max_connections(database_settings.max_connections)
43            .connect(&database_settings.connection_uri)
44            .await
45        {
46            Ok(pool) => {
47                info!("✅ Successfully connected to database");
48                pool
49            }
50            Err(err) => {
51                error!("🚨 Failed to connect to database {:?}", err);
52                std::process::exit(1);
53            }
54        };
55
56        // Run migrations
57        if let Err(err) = Self::run_migrations(&pool).await {
58            error!("🚨 Failed to run migrations {:?}", err);
59            std::process::exit(1);
60        }
61
62        Ok(pool)
63    }
64
65    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66        info!("Running migrations");
67        sqlx::migrate!("src/migrations")
68            .run(pool)
69            .await
70            .map_err(SqlError::MigrateError)?;
71
72        debug!("Migrations complete");
73
74        Ok(())
75    }
76}
77
78pub struct MessageHandler {}
79
80impl MessageHandler {
81    #[instrument(skip(records), name = "Insert Server Records")]
82    pub async fn insert_server_records(
83        pool: &Pool<Postgres>,
84        records: &ServerRecords,
85    ) -> Result<(), SqlError> {
86        match records.record_type()? {
87            RecordType::Spc => {
88                debug!("SPC record count: {:?}", records.len());
89                let records = records.to_spc_drift_records()?;
90                for record in records.iter() {
91                    let _ = PostgresClient::insert_spc_drift_record(pool, record)
92                        .await
93                        .map_err(|e| {
94                            error!("Failed to insert drift record: {:?}", e);
95                        });
96                }
97            }
98            RecordType::Observability => {
99                debug!("Observability record count: {:?}", records.len());
100                let records = records.to_observability_drift_records()?;
101                for record in records.iter() {
102                    let _ = PostgresClient::insert_observability_record(pool, record)
103                        .await
104                        .map_err(|e| {
105                            error!("Failed to insert observability record: {:?}", e);
106                        });
107                }
108            }
109            RecordType::Psi => {
110                debug!("PSI record count: {:?}", records.len());
111                let records = records.to_psi_drift_records()?;
112                for record in records.iter() {
113                    let _ = PostgresClient::insert_bin_counts(pool, record)
114                        .await
115                        .map_err(|e| {
116                            error!("Failed to insert bin count record: {:?}", e);
117                        });
118                }
119            }
120            RecordType::Custom => {
121                debug!("Custom record count: {:?}", records.len());
122                let records = records.to_custom_metric_drift_records()?;
123                for record in records.iter() {
124                    let _ = PostgresClient::insert_custom_metric_value(pool, record)
125                        .await
126                        .map_err(|e| {
127                            error!("Failed to insert bin count record: {:?}", e);
128                        });
129                }
130            }
131        };
132        Ok(())
133    }
134}
135
/// Runs database integration tests.
///
/// Note - binned queries targeting custom intervals with long-term and short-term data are
/// done in the scouter-server integration tests.
///
/// NOTE(review): every test obtains its pool through `db_pool()`, which deletes all rows from
/// the shared scouter tables. The test harness runs tests in parallel by default, so one test's
/// cleanup can remove another test's data mid-run. Run with `--test-threads=1` or confirm the
/// tests are otherwise isolated.
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Delete every row from the tables these tests write to.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        // The statements are DELETEs that return no rows, so `execute` is used instead of
        // `fetch_all`.
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

             DELETE
            FROM scouter.user;
            "#,
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Build a migrated pool from `DatabaseSettings::default()` and clear all test tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        // Captured before the inserts, so every alert below is created after it.
        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // get alerts
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // get alerts limit 1
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // get alerts limit timestamp
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let result = PostgresClient::insert_spc_drift_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let result = PostgresClient::insert_bin_counts(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{}", j),
                    value: j as f64,
                };

                let result = PostgresClient::insert_spc_drift_record(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // NOTE(review): the profile declares features 0..=num_features (4 features), but
        // records below are only inserted for 0..num_features (3 features). The final
        // assertion expects 3 binned features — confirm the mismatch is intentional.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{}", feature);
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let _ = PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        for feature in 0..num_features {
            for bin in 0..=num_bins {
                for _ in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now(),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{}", feature),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    PostgresClient::insert_bin_counts(&pool, &record)
                        .await
                        .unwrap();
                }
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        // feature0 has 6 bins (0..=num_bins), each fed the same number of counts drawn
        // uniformly from 0..10. Bin 1 should therefore hold roughly 1/6 of the total.
        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            for _ in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{}", i),
                    value: rand::rng().random_range(0..10) as f64,
                };

                let result = PostgresClient::insert_custom_metric_value(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        // Insert a lone record for a third metric to check that the statistics functions
        // handle a metric with a single record.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_value(&pool, &record)
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        // metric0, metric1 and metric3
        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // update user
        user.active = false;
        user.refresh_token = Some("token".to_string());

        // Update
        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // get users
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // get last admin
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // delete
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}