1use crate::sql::traits::{
2 AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3 PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
15#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32 #[instrument(skip(database_settings))]
38 pub async fn create_db_pool(
39 database_settings: &DatabaseSettings,
40 ) -> Result<Pool<Postgres>, SqlError> {
41 let pool = match PgPoolOptions::new()
42 .max_connections(database_settings.max_connections)
43 .connect(&database_settings.connection_uri)
44 .await
45 {
46 Ok(pool) => {
47 info!("✅ Successfully connected to database");
48 pool
49 }
50 Err(err) => {
51 error!("🚨 Failed to connect to database {:?}", err);
52 std::process::exit(1);
53 }
54 };
55
56 if let Err(err) = Self::run_migrations(&pool).await {
58 error!("🚨 Failed to run migrations {:?}", err);
59 std::process::exit(1);
60 }
61
62 Ok(pool)
63 }
64
65 pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66 info!("Running migrations");
67 sqlx::migrate!("src/migrations")
68 .run(pool)
69 .await
70 .map_err(SqlError::MigrateError)?;
71
72 debug!("Migrations complete");
73
74 Ok(())
75 }
76}
77
78pub struct MessageHandler {}
79
80impl MessageHandler {
81 #[instrument(skip(records), name = "Insert Server Records")]
82 pub async fn insert_server_records(
83 pool: &Pool<Postgres>,
84 records: &ServerRecords,
85 ) -> Result<(), SqlError> {
86 match records.record_type()? {
87 RecordType::Spc => {
88 debug!("SPC record count: {:?}", records.len());
89 let records = records.to_spc_drift_records()?;
90 for record in records.iter() {
91 let _ = PostgresClient::insert_spc_drift_record(pool, record)
92 .await
93 .map_err(|e| {
94 error!("Failed to insert drift record: {:?}", e);
95 });
96 }
97 }
98 RecordType::Observability => {
99 debug!("Observability record count: {:?}", records.len());
100 let records = records.to_observability_drift_records()?;
101 for record in records.iter() {
102 let _ = PostgresClient::insert_observability_record(pool, record)
103 .await
104 .map_err(|e| {
105 error!("Failed to insert observability record: {:?}", e);
106 });
107 }
108 }
109 RecordType::Psi => {
110 debug!("PSI record count: {:?}", records.len());
111 let records = records.to_psi_drift_records()?;
112 for record in records.iter() {
113 let _ = PostgresClient::insert_bin_counts(pool, record)
114 .await
115 .map_err(|e| {
116 error!("Failed to insert bin count record: {:?}", e);
117 });
118 }
119 }
120 RecordType::Custom => {
121 debug!("Custom record count: {:?}", records.len());
122 let records = records.to_custom_metric_drift_records()?;
123 for record in records.iter() {
124 let _ = PostgresClient::insert_custom_metric_value(pool, record)
125 .await
126 .map_err(|e| {
127 error!("Failed to insert bin count record: {:?}", e);
128 });
129 }
130 }
131 };
132 Ok(())
133 }
134}
135
#[cfg(test)]
mod tests {
    // Integration tests for the Postgres client.
    //
    // These tests need a Postgres instance reachable with `DatabaseSettings::default()`.
    // NOTE(review): every test calls `db_pool()`, which deletes all rows from the scouter
    // tables. The default test harness runs tests in parallel, so tests sharing one database
    // can interfere with each other (e.g. `test_postgres_user` asserts an exact user count).
    // Run with `--test-threads=1`, or serialize these tests.

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Deletes every row from the scouter tables these tests write to.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        // DELETE statements return no rows, so `execute` is the appropriate call here.
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;
            "#,
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Creates a migrated pool from default settings and wipes the test tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let result = PostgresClient::insert_spc_drift_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let result = PostgresClient::insert_bin_counts(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{}", j),
                    value: j as f64,
                };

                let result = PostgresClient::insert_spc_drift_record(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // The profile declares features 0..=num_features, but drift data below is only
        // inserted for features 0..num_features, so only `num_features` of them get
        // binned records.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{}", feature);
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let _ = PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        for feature in 0..num_features {
            for bin in 0..=num_bins {
                for _ in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now(),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{}", feature),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    PostgresClient::insert_bin_counts(&pool, &record)
                        .await
                        .unwrap();
                }
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        // Each of the num_bins + 1 bins receives identically distributed random counts.
        // Presumably the proportion is this bin's share of the feature total, so it should
        // land near 1/6; confirm against get_feature_distributions.
        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            for _ in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{}", i),
                    value: rand::rng().random_range(0..10) as f64,
                };

                let result = PostgresClient::insert_custom_metric_value(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_value(&pool, &record)
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        user.active = false;
        user.refresh_token = Some("token".to_string());

        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}