1use crate::sql::traits::{
2 AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, ObservabilitySqlLogic, ProfileSqlLogic,
3 PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8
9use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
10
11use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
/// Stateless handle for the Postgres-backed SQL operations.
///
/// The struct holds no fields. Its functionality comes from two sources:
/// - associated functions such as `create_db_pool` and `run_migrations`;
/// - the SQL logic traits implemented for it (`SpcSqlLogic`, `PsiSqlLogic`, ...).
///
/// Callers pass an explicit `&Pool<Postgres>` to every operation.
#[derive(Debug, Clone)]
// NOTE(review): appears to be used only as a path for associated/trait functions and never
// constructed, which would otherwise trigger the "never constructed" dead_code warning — confirm.
#[allow(dead_code)]
pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl UserSqlLogic for PostgresClient {}
26impl ProfileSqlLogic for PostgresClient {}
27impl ObservabilitySqlLogic for PostgresClient {}
28impl AlertSqlLogic for PostgresClient {}
29impl ArchiveSqlLogic for PostgresClient {}
30
31impl PostgresClient {
32 #[instrument(skip(database_settings))]
38 pub async fn create_db_pool(
39 database_settings: &DatabaseSettings,
40 ) -> Result<Pool<Postgres>, SqlError> {
41 let pool = match PgPoolOptions::new()
42 .max_connections(database_settings.max_connections)
43 .connect(&database_settings.connection_uri)
44 .await
45 {
46 Ok(pool) => {
47 info!("✅ Successfully connected to database");
48 pool
49 }
50 Err(err) => {
51 error!("🚨 Failed to connect to database {:?}", err);
52 std::process::exit(1);
53 }
54 };
55
56 if let Err(err) = Self::run_migrations(&pool).await {
58 error!("🚨 Failed to run migrations {:?}", err);
59 std::process::exit(1);
60 }
61
62 Ok(pool)
63 }
64
65 pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
66 info!("Running migrations");
67 sqlx::migrate!("src/migrations")
68 .run(pool)
69 .await
70 .map_err(SqlError::MigrateError)?;
71
72 debug!("Migrations complete");
73
74 Ok(())
75 }
76}
77
/// Stateless entry point for writing batches of incoming `ServerRecords` to Postgres.
pub struct MessageHandler {}
79
80impl MessageHandler {
81 const DEFAULT_BATCH_SIZE: usize = 500;
82 #[instrument(skip_all)]
83 pub async fn insert_server_records(
84 pool: &Pool<Postgres>,
85 records: &ServerRecords,
86 ) -> Result<(), SqlError> {
87 debug!("Inserting server records: {:?}", records);
88
89 match records.record_type()? {
90 RecordType::Spc => {
91 let spc_records = records.to_spc_drift_records()?;
92 debug!("SPC record count: {}", spc_records.len());
93
94 for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
95 PostgresClient::insert_spc_drift_records_batch(pool, chunk)
96 .await
97 .map_err(|e| {
98 error!("Failed to insert SPC drift records batch: {:?}", e);
99 e
100 })?;
101 }
102 }
103
104 RecordType::Psi => {
105 let psi_records = records.to_psi_drift_records()?;
106 debug!("PSI record count: {}", psi_records.len());
107
108 for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
109 PostgresClient::insert_bin_counts_batch(pool, chunk)
110 .await
111 .map_err(|e| {
112 error!("Failed to insert PSI drift records batch: {:?}", e);
113 e
114 })?;
115 }
116 }
117 RecordType::Custom => {
118 let custom_records = records.to_custom_metric_drift_records()?;
119 debug!("Custom record count: {}", custom_records.len());
120
121 for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
122 PostgresClient::insert_custom_metric_values_batch(pool, chunk)
123 .await
124 .map_err(|e| {
125 error!("Failed to insert custom metric records batch: {:?}", e);
126 e
127 })?;
128 }
129 }
130
131 _ => {
132 error!(
133 "Unsupported record type for batch insert: {:?}",
134 records.record_type()?
135 );
136 return Err(SqlError::UnsupportedBatchTypeError);
137 }
138 }
139
140 Ok(())
141 }
142}
143
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use rand::Rng;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Deletes every row from the tables the tests in this module write to.
    ///
    /// NOTE(review): `#[tokio::test]`s run in parallel by default. All of them share one
    /// database and call this at start-up, so one test's cleanup can delete another
    /// test's rows mid-run. Consider serializing these tests.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        // The script contains only DELETE statements, which return no rows, so
        // `execute` is used rather than `fetch_all`.
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;
            "#,
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Creates a pool from the default settings (running migrations) and empties the test tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record1 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let record2 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            value: 2.0,
        };

        let result = PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record1 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let record2 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            bin_id: 2,
            bin_count: 2,
        };

        let result = PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();

        let result =
            PostgresClient::insert_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let mut records = Vec::new();
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{j}"),
                    value: j as f64,
                };

                records.push(record);
            }

            let result = PostgresClient::insert_spc_drift_records_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), records.len() as u64);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // The profile declares features 0..=num_features, but bin counts are inserted
        // below only for features 0..num_features.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{feature}");
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let _ = PostgresClient::insert_drift_profile(
            &pool,
            &DriftProfile::Psi(psi::PsiDriftProfile::new(
                features,
                PsiDriftConfig {
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    ..Default::default()
                },
                None,
            )),
        )
        .await
        .unwrap();

        for feature in 0..num_features {
            for bin in 0..=num_bins {
                let mut records = Vec::new();
                for j in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{feature}"),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    records.push(record);
                }
                PostgresClient::insert_bin_counts_batch(&pool, &records)
                    .await
                    .unwrap();
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        // With 6 bins of similarly distributed random counts, each bin should hold roughly 1/6.
        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            let mut records = Vec::new();
            for j in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };
                records.push(record);
            }
            let result = PostgresClient::insert_custom_metric_values_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), 25);
        }

        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_values_batch(&pool, &[record])
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        user.active = false;
        user.refresh_token = Some("token".to_string());

        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
}