scouter_sql/sql/
postgres.rs

1use crate::sql::traits::{
2    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3    PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
15// TODO: Explore refactoring and breaking this out into multiple client types (i.e., spc, psi, etc.)
16// Postgres client is one of the lowest-level abstractions so it may not be worth it, as it could make server logic annoying. Worth exploring though.
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32    /// Setup the application with the given database pool.
33    ///
34    /// # Returns
35    ///
36    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
37    #[instrument(skip(database_settings))]
38    pub async fn create_db_pool(
39        database_settings: &DatabaseSettings,
40    ) -> Result<Pool<Postgres>, SqlError> {
41        let pool = match PgPoolOptions::new()
42            .max_connections(database_settings.max_connections)
43            .connect(&database_settings.connection_uri)
44            .await
45        {
46            Ok(pool) => {
47                info!("✅ Successfully connected to database");
48                pool
49            }
50            Err(err) => {
51                error!("🚨 Failed to connect to database {:?}", err);
52                std::process::exit(1);
53            }
54        };
55
56        // Run migrations
57        if let Err(err) = Self::run_migrations(&pool).await {
58            error!("🚨 Failed to run migrations {:?}", err);
59            std::process::exit(1);
60        }
61
62        Ok(pool)
63    }
64
65    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66        info!("Running migrations");
67        sqlx::migrate!("src/migrations")
68            .run(pool)
69            .await
70            .map_err(SqlError::MigrateError)?;
71
72        debug!("Migrations complete");
73
74        Ok(())
75    }
76}
77
78pub struct MessageHandler {}
79
80impl MessageHandler {
81    #[instrument(skip(records), name = "Insert Server Records")]
82    pub async fn insert_server_records(
83        pool: &Pool<Postgres>,
84        records: &ServerRecords,
85    ) -> Result<(), SqlError> {
86        match records.record_type()? {
87            RecordType::Spc => {
88                debug!("SPC record count: {:?}", records.len());
89                let records = records.to_spc_drift_records()?;
90                for record in records.iter() {
91                    let _ = PostgresClient::insert_spc_drift_record(pool, record)
92                        .await
93                        .map_err(|e| {
94                            error!("Failed to insert drift record: {:?}", e);
95                        });
96                }
97            }
98            RecordType::Observability => {
99                debug!("Observability record count: {:?}", records.len());
100                let records = records.to_observability_drift_records()?;
101                for record in records.iter() {
102                    let _ = PostgresClient::insert_observability_record(pool, record)
103                        .await
104                        .map_err(|e| {
105                            error!("Failed to insert observability record: {:?}", e);
106                        });
107                }
108            }
109            RecordType::Psi => {
110                debug!("PSI record count: {:?}", records.len());
111                let records = records.to_psi_drift_records()?;
112                for record in records.iter() {
113                    let _ = PostgresClient::insert_bin_counts(pool, record)
114                        .await
115                        .map_err(|e| {
116                            error!("Failed to insert bin count record: {:?}", e);
117                        });
118                }
119            }
120            RecordType::Custom => {
121                debug!("Custom record count: {:?}", records.len());
122                let records = records.to_custom_metric_drift_records()?;
123                for record in records.iter() {
124                    let _ = PostgresClient::insert_custom_metric_value(pool, record)
125                        .await
126                        .map_err(|e| {
127                            error!("Failed to insert bin count record: {:?}", e);
128                        });
129                }
130            }
131        };
132        Ok(())
133    }
134}
135
/// Runs database integration tests.
/// Note - binned queries targeting custom intervals with long-term and short-term data are
/// done in the scouter-server integration tests.
///
/// NOTE(review): every test calls `db_pool`, which runs `cleanup` and deletes rows from shared
/// tables. Rust's test harness runs tests concurrently by default, so one test's cleanup may
/// remove rows another test is still asserting on. If assertions flake, consider running this
/// module with `--test-threads=1` (or otherwise serializing the tests).
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;
    // Identifiers shared by the records and requests built in these tests.
    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Delete all rows from the tables these tests write to, intended to give each test an
    /// empty starting state.
    ///
    /// `raw_sql` is used because the script contains several `;`-separated statements.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

             DELETE
            FROM scouter.user;
            "#,
        )
        .fetch_all(pool)
        .await
        .unwrap();
    }

    /// Build a pool from `DatabaseSettings::default()` (`create_db_pool` also runs the
    /// migrations), then wipe the test tables before returning it.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    /// Smoke test: connecting, migrating and cleaning up succeed.
    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    /// Insert 10 SPC alerts, then read them back with no limit, with `limit: Some(1)`, and
    /// with a `limit_datetime` taken before the inserts.
    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        // Captured before any alert is inserted; used as `limit_datetime` below.
        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // get alerts
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // get alerts limit 1
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // get alerts limit timestamp
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    /// A single SPC drift record insert affects exactly one row.
    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let result = PostgresClient::insert_spc_drift_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    /// A single PSI bin-count insert affects exactly one row.
    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let result = PostgresClient::insert_bin_counts(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    /// Inserting a default `ObservabilityMetrics` record affects exactly one row.
    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    /// Insert, update and read back an SPC drift profile, then deactivate it.
    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        // Change a field so the update below has something observable to persist.
        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        // The stored profile should round-trip to the updated in-memory profile.
        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    /// Insert 10 rounds of SPC records for 10 features (`test0`..`test9`), then check feature
    /// discovery, raw record retrieval and binned retrieval all report 10 features.
    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        // Captured before the inserts; passed to `get_spc_drift_records` below.
        let timestamp = Utc::now();

        for _ in 0..10 {
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{}", j),
                    value: j as f64,
                };

                let result = PostgresClient::insert_spc_drift_record(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    /// Insert random PSI bin counts for 3 features x 6 bins (101 inserts per bin), then check
    /// the bin proportions for `feature0` and the binned PSI drift records.
    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for feature in 0..3 {
            for bin in 0..=5 {
                for _ in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now(),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{}", feature),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    PostgresClient::insert_bin_counts(&pool, &record)
                        .await
                        .unwrap();
                }
            }
        }

        let binned_records = PostgresClient::get_feature_bin_proportions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        // Each of the six bins (0..=5) received the same number of inserts with uniformly random
        // counts. If the proportion is bin 1's share of all counts for feature0 (presumably how
        // `get_feature_bin_proportions` computes it), it should be roughly 1/6 (~0.167).
        let bin_proportion = binned_records
            .features
            .get("feature0")
            .unwrap()
            .get(&1)
            .unwrap();

        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        // One entry per inserted feature (feature0..feature2).
        assert_eq!(binned_records.len(), 3);
    }

    /// Insert custom metric values for `metric0`/`metric1` (25 each) plus a single `metric3`
    /// value, then read current values and binned drift records back.
    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            for _ in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{}", i),
                    value: rand::rng().random_range(0..10) as f64,
                };

                let result = PostgresClient::insert_custom_metric_value(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        // Insert a lone record for a third metric to check that the statistics functions handle
        // a metric with only a single data point.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_value(&pool, &record)
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        // One entry per inserted metric: metric0, metric1 and metric3.
        assert_eq!(binned_records.metrics.len(), 3);
    }

    /// Create, read, update, list and delete a single user.
    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // update user
        user.active = false;
        user.refresh_token = Some("token".to_string());

        // Update
        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // get users
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // get last admin: this user presumably is not an admin (its group permissions are
        // just ["user"]), so it should not be reported as the last admin.
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // delete
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}