1use crate::sql::traits::{
2 AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3 PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
/// Stateless namespace type for the Postgres-backed SQL logic.
///
/// It carries no fields: its functionality is reached through associated
/// functions, either inherent ones (pool creation, migrations) or those
/// provided by the SQL logic traits implemented for it.
// NOTE(review): presumably never constructed — only its associated functions
// are called — which is why `dead_code` is allowed here; confirm.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32 #[instrument(skip(database_settings))]
38 pub async fn create_db_pool(
39 database_settings: &DatabaseSettings,
40 ) -> Result<Pool<Postgres>, SqlError> {
41 let pool = match PgPoolOptions::new()
42 .max_connections(database_settings.max_connections)
43 .connect(&database_settings.connection_uri)
44 .await
45 {
46 Ok(pool) => {
47 info!("✅ Successfully connected to database");
48 pool
49 }
50 Err(err) => {
51 error!("🚨 Failed to connect to database {:?}", err);
52 std::process::exit(1);
53 }
54 };
55
56 if let Err(err) = Self::run_migrations(&pool).await {
58 error!("🚨 Failed to run migrations {:?}", err);
59 std::process::exit(1);
60 }
61
62 Ok(pool)
63 }
64
65 pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66 info!("Running migrations");
67 sqlx::migrate!("src/migrations")
68 .run(pool)
69 .await
70 .map_err(SqlError::MigrateError)?;
71
72 debug!("Migrations complete");
73
74 Ok(())
75 }
76}
77
/// Stateless dispatcher that routes a [`ServerRecords`] batch to the matching
/// `PostgresClient` insert function based on its [`RecordType`].
pub struct MessageHandler {}
79
80impl MessageHandler {
81 #[instrument(skip_all)]
82 pub async fn insert_server_records(
83 pool: &Pool<Postgres>,
84 records: &ServerRecords,
85 ) -> Result<(), SqlError> {
86 debug!("Inserting server records: {:?}", records);
87 match records.record_type()? {
88 RecordType::Spc => {
89 debug!("SPC record count: {:?}", records.len());
90 let records = records.to_spc_drift_records()?;
91 for record in records.iter() {
92 let _ = PostgresClient::insert_spc_drift_record(pool, record)
93 .await
94 .map_err(|e| {
95 error!("Failed to insert drift record: {:?}", e);
96 });
97 }
98 }
99 RecordType::Observability => {
100 debug!("Observability record count: {:?}", records.len());
101 let records = records.to_observability_drift_records()?;
102 for record in records.iter() {
103 let _ = PostgresClient::insert_observability_record(pool, record)
104 .await
105 .map_err(|e| {
106 error!("Failed to insert observability record: {:?}", e);
107 });
108 }
109 }
110 RecordType::Psi => {
111 debug!("PSI record count: {:?}", records.len());
112 let records = records.to_psi_drift_records()?;
113 for record in records.iter() {
114 let _ = PostgresClient::insert_bin_counts(pool, record)
115 .await
116 .map_err(|e| {
117 error!("Failed to insert bin count record: {:?}", e);
118 });
119 }
120 }
121 RecordType::Custom => {
122 debug!("Custom record count: {:?}", records.len());
123 let records = records.to_custom_metric_drift_records()?;
124 for record in records.iter() {
125 let _ = PostgresClient::insert_custom_metric_value(pool, record)
126 .await
127 .map_err(|e| {
128 error!("Failed to insert bin count record: {:?}", e);
129 });
130 }
131 }
132 };
133 Ok(())
134 }
135}
136
#[cfg(test)]
mod tests {
    //! Integration tests that run against the live Postgres instance described
    //! by `DatabaseSettings::default()`.
    //!
    //! NOTE(review): every test obtains its pool through `db_pool()`, which
    //! deletes all rows from the shared `scouter` tables. The default test
    //! harness runs tests in parallel, so one test's cleanup can remove another
    //! test's rows mid-run (e.g. breaking the exact counts asserted in
    //! `test_postgres_get_features`). Presumably these are run with
    //! `--test-threads=1`; confirm.

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Deletes every row from the tables these tests write to.
    ///
    /// Panics if any of the DELETE statements fails.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        // The DELETEs return no rows, so `execute` is the right call; there is
        // nothing to collect with `fetch_all`.
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;
            "#,
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Creates a migrated pool from the default settings and wipes the test tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // All active alerts.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // Limited to a single alert.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // Filtered by the timestamp captured before the inserts.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let result = PostgresClient::insert_spc_drift_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let result = PostgresClient::insert_bin_counts(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{j}"),
                    value: j as f64,
                };

                let result = PostgresClient::insert_spc_drift_record(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // The profile declares features 0..=num_features, but bin counts are
        // only inserted for 0..num_features below.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{feature}");
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        for feature in 0..num_features {
            for bin in 0..=num_bins {
                for _ in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now(),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{feature}"),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    PostgresClient::insert_bin_counts(&pool, &record)
                        .await
                        .unwrap();
                }
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        // Six bins with uniformly random counts: each proportion should be near 1/6.
        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            for _ in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now(),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };

                let result = PostgresClient::insert_custom_metric_value(&pool, &record)
                    .await
                    .unwrap();
                assert_eq!(result.rows_affected(), 1);
            }
        }

        // A third metric with a single data point.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_value(&pool, &record)
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create a user.
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read it back.
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // Update it and confirm the changes persisted.
        user.active = false;
        user.refresh_token = Some("token".to_string());

        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // List users.
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // A non-admin user is never the last admin.
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // Delete it.
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}