1use crate::sql::error::SqlError;
2use crate::sql::traits::{
3 AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, LLMDriftSqlLogic, ObservabilitySqlLogic,
4 ProfileSqlLogic, PsiSqlLogic, SpcSqlLogic, TagSqlLogic, TraceSqlLogic, UserSqlLogic,
5};
6use scouter_settings::DatabaseSettings;
7use scouter_types::{RecordType, ServerRecords, TagRecord, ToDriftRecords, TraceServerRecord};
8use sqlx::ConnectOptions;
9use sqlx::{postgres::PgConnectOptions, Pool, Postgres};
10use std::result::Result::Ok;
11use tokio::try_join;
12use tracing::{debug, error, info, instrument};
13
/// Stateless handle for Postgres access; all SQL behavior is supplied by the
/// blanket trait implementations below, so no fields are needed.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct PostgresClient {}
17
// Each SQL concern is implemented as a trait with default methods; attaching
// the empty impls here gives `PostgresClient` the full query surface.
impl SpcSqlLogic for PostgresClient {}
impl CustomMetricSqlLogic for PostgresClient {}
impl PsiSqlLogic for PostgresClient {}
impl LLMDriftSqlLogic for PostgresClient {}
impl UserSqlLogic for PostgresClient {}
impl ProfileSqlLogic for PostgresClient {}
impl ObservabilitySqlLogic for PostgresClient {}
impl AlertSqlLogic for PostgresClient {}
impl ArchiveSqlLogic for PostgresClient {}
impl TraceSqlLogic for PostgresClient {}
impl TagSqlLogic for PostgresClient {}
29
30impl PostgresClient {
31 #[instrument(skip(database_settings))]
37 pub async fn create_db_pool(
38 database_settings: &DatabaseSettings,
39 ) -> Result<Pool<Postgres>, SqlError> {
40 let mut opts: PgConnectOptions = database_settings.connection_uri.parse()?;
41
42 opts = opts.log_statements(tracing::log::LevelFilter::Off);
45
46 let pool = match sqlx::postgres::PgPoolOptions::new()
47 .max_connections(database_settings.max_connections)
48 .connect_with(opts)
49 .await
50 {
51 Ok(pool) => {
52 info!("✅ Successfully connected to database");
53 pool
54 }
55 Err(err) => {
56 error!("🚨 Failed to connect to database {:?}", err);
57 std::process::exit(1);
58 }
59 };
60
61 if let Err(err) = Self::run_migrations(&pool).await {
63 error!("🚨 Failed to run migrations {:?}", err);
64 std::process::exit(1);
65 }
66
67 Ok(pool)
68 }
69
70 pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
71 info!("Running migrations");
72 sqlx::migrate!("src/migrations")
73 .run(pool)
74 .await
75 .map_err(SqlError::MigrateError)?;
76
77 debug!("Migrations complete");
78
79 Ok(())
80 }
81}
82
83pub struct MessageHandler {}
84
85impl MessageHandler {
86 const DEFAULT_BATCH_SIZE: usize = 500;
87 #[instrument(skip_all)]
88 pub async fn insert_server_records(
89 pool: &Pool<Postgres>,
90 records: &ServerRecords,
91 ) -> Result<(), SqlError> {
92 debug!("Inserting server records: {:?}", records.record_type()?);
93
94 match records.record_type()? {
95 RecordType::Spc => {
96 let spc_records = records.to_spc_drift_records()?;
97 debug!("SPC record count: {}", spc_records.len());
98
99 for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
100 PostgresClient::insert_spc_drift_records_batch(pool, chunk)
101 .await
102 .map_err(|e| {
103 error!("Failed to insert SPC drift records batch: {:?}", e);
104 e
105 })?;
106 }
107 }
108
109 RecordType::Psi => {
110 let psi_records = records.to_psi_drift_records()?;
111 debug!("PSI record count: {}", psi_records.len());
112
113 for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
114 PostgresClient::insert_bin_counts_batch(pool, chunk)
115 .await
116 .map_err(|e| {
117 error!("Failed to insert PSI drift records batch: {:?}", e);
118 e
119 })?;
120 }
121 }
122 RecordType::Custom => {
123 let custom_records = records.to_custom_metric_drift_records()?;
124 debug!("Custom record count: {}", custom_records.len());
125
126 for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
127 PostgresClient::insert_custom_metric_values_batch(pool, chunk)
128 .await
129 .map_err(|e| {
130 error!("Failed to insert custom metric records batch: {:?}", e);
131 e
132 })?;
133 }
134 }
135
136 RecordType::LLMDrift => {
137 debug!("LLM Drift record count: {:?}", records.len());
138 let records = records.to_llm_drift_records()?;
139 for record in records.iter() {
140 let _ = PostgresClient::insert_llm_drift_record(pool, record)
141 .await
142 .map_err(|e| {
143 error!("Failed to insert LLM drift record: {:?}", e);
144 });
145 }
146 }
147
148 RecordType::LLMMetric => {
149 debug!("LLM Metric record count: {:?}", records.len());
150 let llm_metric_records = records.to_llm_metric_records()?;
151
152 for chunk in llm_metric_records.chunks(Self::DEFAULT_BATCH_SIZE) {
153 PostgresClient::insert_llm_metric_values_batch(pool, chunk)
154 .await
155 .map_err(|e| {
156 error!("Failed to insert LLM metric records batch: {:?}", e);
157 e
158 })?;
159 }
160 }
161
162 _ => {
163 error!(
164 "Unsupported record type for batch insert: {:?}",
165 records.record_type()?
166 );
167 return Err(SqlError::UnsupportedBatchTypeError);
168 }
169 }
170
171 Ok(())
172 }
173
174 pub async fn insert_trace_server_record(
175 pool: &Pool<Postgres>,
176 records: &TraceServerRecord,
177 ) -> Result<(), SqlError> {
178 let (trace_batch, span_batch, baggage_batch) = records.to_records()?;
179
180 let all_tags: Vec<TagRecord> = trace_batch
181 .iter()
182 .flat_map(|trace| {
183 trace.tags.iter().map(|tag| TagRecord {
184 created_at: trace.created_at,
185 entity_type: "trace".to_string(),
186 entity_id: trace.trace_id.clone(),
187 key: tag.key.clone(),
188 value: tag.value.clone(),
189 })
190 })
191 .collect();
192
193 let (trace_result, span_result, baggage_result, tag_result) = try_join!(
194 PostgresClient::upsert_trace_batch(pool, &trace_batch),
195 PostgresClient::insert_span_batch(pool, &span_batch),
196 PostgresClient::insert_trace_baggage_batch(pool, &baggage_batch),
197 async {
198 if !all_tags.is_empty() {
199 PostgresClient::insert_tag_batch(pool, &all_tags).await
200 } else {
201 Ok(sqlx::postgres::PgQueryResult::default())
202 }
203 }
204 )?;
205
206 debug!(
207 trace_rows = trace_result.rows_affected(),
208 span_rows = span_result.rows_affected(),
209 baggage_rows = baggage_result.rows_affected(),
210 tag_rows = tag_result.rows_affected(),
211 total_traces = trace_batch.len(),
212 total_spans = span_batch.len(),
213 total_baggage = baggage_batch.len(),
214 total_tags = all_tags.len(),
215 "Successfully inserted trace server records"
216 );
217 Ok(())
218 }
219
220 pub async fn insert_tag_record(
221 pool: &Pool<Postgres>,
222 record: &TagRecord,
223 ) -> Result<(), SqlError> {
224 let result = PostgresClient::insert_tag_batch(pool, std::slice::from_ref(record)).await?;
225
226 debug!(
227 rows_affected = result.rows_affected(),
228 entity_type = record.entity_type.as_str(),
229 entity_id = record.entity_id.as_str(),
230 key = record.key.as_str(),
231 "Successfully inserted tag record"
232 );
233
234 Ok(())
235 }
236}
237
238#[cfg(test)]
242mod tests {
243
244 use super::*;
245 use crate::sql::schema::User;
246 use chrono::{Duration, Utc};
247 use potato_head::{create_score_prompt, create_uuid7};
248 use rand::Rng;
249 use scouter_semver::VersionType;
250 use scouter_settings::ObjectStorageSettings;
251 use scouter_types::llm::PaginationRequest;
252 use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
253 use scouter_types::spc::SpcDriftProfile;
254 use scouter_types::sql::TraceFilters;
255 use scouter_types::*;
256 use serde_json::Value;
257 use sqlx::postgres::PgQueryResult;
258 use std::collections::BTreeMap;
259
    // Fixed identity shared by every test record so queries can filter on it.
    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";
    const SCOPE: &str = "scope";
264
265 fn random_trace_record() -> TraceRecord {
266 let mut rng = rand::rng();
267 let random_num = rng.random_range(0..1000);
268 let trace_id: String = (0..32)
269 .map(|_| format!("{:x}", rng.random_range(0..16)))
270 .collect();
271 let span_id: String = (0..16)
272 .map(|_| format!("{:x}", rng.random_range(0..16)))
273 .collect();
274 let created_at = Utc::now() + chrono::Duration::milliseconds(random_num);
275
276 TraceRecord {
277 trace_id: trace_id.clone(),
278 created_at,
279 space: SPACE.to_string(),
280 name: NAME.to_string(),
281 version: VERSION.to_string(),
282 scope: SCOPE.to_string(),
283 trace_state: "running".to_string(),
284 start_time: created_at,
285 end_time: created_at + chrono::Duration::milliseconds(150),
286 duration_ms: 150,
287 status_code: 0,
288 span_count: 1,
289 status_message: "OK".to_string(),
290 root_span_id: span_id.clone(),
291 tags: vec![],
292 }
293 }
294
295 fn random_span_record(trace_id: &str, parent_span_id: Option<&str>) -> TraceSpanRecord {
296 let mut rng = rand::rng();
297 let span_id: String = (0..16)
298 .map(|_| format!("{:x}", rng.random_range(0..16)))
299 .collect();
300
301 let random_offset_ms = rng.random_range(0..1000);
302 let duration_ms_val = rng.random_range(50..500);
303
304 let created_at = Utc::now() + chrono::Duration::milliseconds(random_offset_ms);
305 let start_time = created_at;
306 let end_time = start_time + chrono::Duration::milliseconds(duration_ms_val);
307
308 let status_code = if rng.random_bool(0.95) { 0 } else { 2 };
310 let span_kind_options = ["SERVER", "CLIENT", "INTERNAL", "PRODUCER", "CONSUMER"];
311 let span_kind = span_kind_options[rng.random_range(0..span_kind_options.len())].to_string();
312
313 TraceSpanRecord {
314 created_at,
315 span_id,
316 trace_id: trace_id.to_string(),
317 parent_span_id: parent_span_id.map(|s| s.to_string()),
318 space: SPACE.to_string(),
319 name: NAME.to_string(),
320 version: VERSION.to_string(),
321 scope: SCOPE.to_string(),
322 span_name: format!("{}_{}", "random_operation", rng.random_range(0..10)),
323 span_kind,
324 start_time,
325 end_time,
326 duration_ms: duration_ms_val,
327 status_code,
328 status_message: if status_code == 2 {
329 "Internal Server Error".to_string()
330 } else {
331 "OK".to_string()
332 },
333 attributes: vec![Attribute::default()],
334 events: vec![],
335 links: vec![],
336 label: None,
337 input: Value::default(),
338 output: Value::default(),
339 }
340 }
341
    /// Deletes every row from the scouter tables touched by these tests so
    /// each test starts from an empty database.
    // NOTE(review): `.fetch_all` is used to run a multi-statement DELETE
    // script; `.execute` would be more conventional — confirm intent.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;

            DELETE
            FROM scouter.llm_drift_record;

            DELETE
            FROM scouter.llm_drift;

            DELETE
            FROM scouter.spans;

            DELETE
            FROM scouter.trace_baggage;

            DELETE
            FROM scouter.traces;

            DELETE
            FROM scouter.tags;
            "#,
        )
        .fetch_all(pool)
        .await
        .unwrap();
    }

    /// Creates a pool against the default database settings and wipes all
    /// test tables before handing it to the caller.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }
399
400 pub async fn insert_profile_to_db(
401 pool: &Pool<Postgres>,
402 profile: &DriftProfile,
403 active: bool,
404 deactivate_others: bool,
405 ) -> PgQueryResult {
406 let base_args = profile.get_base_args();
407 let version = PostgresClient::get_next_profile_version(
408 pool,
409 &base_args,
410 VersionType::Minor,
411 None,
412 None,
413 )
414 .await
415 .unwrap();
416
417 let result = PostgresClient::insert_drift_profile(
418 pool,
419 profile,
420 &base_args,
421 &version,
422 &active,
423 &deactivate_others,
424 )
425 .await
426 .unwrap();
427
428 result
429 }
430
    /// Smoke test: pool creation, migrations, and cleanup succeed against a
    /// live database.
    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }
435
    /// Inserts ten alerts, then exercises the three alert-query variants:
    /// unbounded, row-limited, and timestamp-bounded.
    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let task_info = DriftTaskInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                uid: "test".to_string(),
                drift_type: DriftType::Spc,
            };

            // Dummy alert payload: keys and values "0".."9".
            let alert = (0..10)
                .map(|i| (i.to_string(), i.to_string()))
                .collect::<BTreeMap<String, String>>();

            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // Unbounded query returns everything inserted above.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // Row limit caps the result set.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // Timestamp bound set before the inserts still matches all rows.
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }
513
    /// Batch-inserts two SPC drift records and checks the affected row count.
    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record1 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let record2 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            value: 2.0,
        };

        let result = PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    /// Batch-inserts two PSI bin-count records and checks the affected row
    /// count.
    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record1 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let record2 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            bin_id: 2,
            bin_count: 2,
        };

        let result = PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }
573
    /// Inserts a default observability metric row.
    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    /// Round-trips a drift profile: insert, update, fetch + deserialize,
    /// then flip its active status.
    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();
        let profile = DriftProfile::Spc(spc_profile.clone());

        let result = insert_profile_to_db(&pool, &profile, false, false).await;
        assert_eq!(result.rows_affected(), 1);

        // Mutate a field and persist the update.
        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        // The stored JSON must deserialize back to the updated profile.
        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }
636
    /// Inserts 10 batches of 10 SPC records (features test0..test9), then
    /// verifies feature listing, raw record retrieval, and binned retrieval.
    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let mut records = Vec::new();
            for j in 0..10 {
                let record = SpcServerRecord {
                    // Microsecond offsets keep created_at values distinct.
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{j}"),
                    value: j as f64,
                };

                records.push(record);
            }

            let result = PostgresClient::insert_spc_drift_records_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), records.len() as u64);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }
701
    /// Builds a PSI profile with uniform random bin counts, then checks that
    /// each bin's computed proportion is near the expected uniform share, and
    /// that binned PSI retrieval returns all inserted features.
    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // Feature profiles feature0..feature3, each with bins 0..=5.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bind_id| Bin {
                        id: bind_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{feature}");
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let profile = &DriftProfile::Psi(psi::PsiDriftProfile::new(
            features,
            PsiDriftConfig {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                ..Default::default()
            },
        ));
        let _ = insert_profile_to_db(&pool, profile, false, false).await;

        // Uniform random counts per (feature, bin) pair.
        for feature in 0..num_features {
            for bin in 0..=num_bins {
                let mut records = Vec::new();
                for j in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{feature}"),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    records.push(record);
                }
                PostgresClient::insert_bin_counts_batch(&pool, &records)
                    .await
                    .unwrap();
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        // Six bins with uniform counts → each proportion should be ~1/6.
        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        assert_eq!(binned_records.len(), 3);
    }
808
    /// Inserts batches of custom metric values plus one extra metric, then
    /// verifies per-metric retrieval and binned retrieval counts.
    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        // Two metrics (metric0, metric1) with 25 values each.
        for i in 0..2 {
            let mut records = Vec::new();
            for j in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };
                records.push(record);
            }
            let result = PostgresClient::insert_custom_metric_values_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), 25);
        }

        // A third metric with a single value.
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_values_batch(&pool, &[record])
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();
        // metric0, metric1, metric3 → three distinct metrics.
        assert_eq!(binned_records.metrics.len(), 3);
    }
883
    /// Full user lifecycle: insert, fetch, update, list, admin check, delete.
    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // Deactivate and attach a refresh token, then verify persistence.
        user.active = false;
        user.refresh_token = Some("token".to_string());

        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // A non-admin user is never the "last admin".
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }
936
    /// Inserts ten pending LLM drift records, then exercises retrieval,
    /// pending-task lookup, and the status transition to Processed.
    #[tokio::test]
    async fn test_postgres_llm_drift_record_insert_get() {
        let pool = db_pool().await;

        let input = "This is a test input";
        let output = "This is a test response";
        let prompt = create_score_prompt(None);

        for j in 0..10 {
            let context = serde_json::json!({
                "input": input,
                "response": output,
            });
            let record = LLMDriftServerRecord {
                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                prompt: Some(prompt.model_dump_value()),
                context,
                status: Status::Pending,
                // id is assigned by the database; 0 is a placeholder.
                id: 0,
                uid: "test".to_string(),
                updated_at: None,
                score: Value::Null,
                processing_started_at: None,
                processing_ended_at: None,
                processing_duration: None,
            };

            let result = PostgresClient::insert_llm_drift_record(&pool, &record)
                .await
                .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_llm_drift_records(&pool, &service_info, None, None)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let pending_tasks = PostgresClient::get_pending_llm_drift_record(&pool)
            .await
            .unwrap();

        assert!(pending_tasks.is_some());

        let task_input = &pending_tasks.as_ref().unwrap().context["input"];
        assert_eq!(*task_input, "This is a test input".to_string());

        // Mark one task processed and confirm it shows up in the filter.
        PostgresClient::update_llm_drift_record_status(
            &pool,
            &pending_tasks.unwrap(),
            Status::Processed,
            Some(1),
        )
        .await
        .unwrap();

        let processed_tasks = PostgresClient::get_llm_drift_records(
            &pool,
            &service_info,
            None,
            Some(Status::Processed),
        )
        .await
        .unwrap();

        assert_eq!(processed_tasks.len(), 1);
    }
1020
1021 #[tokio::test]
1022 async fn test_postgres_llm_drift_record_pagination() {
1023 let pool = db_pool().await;
1024
1025 let input = "This is a test input";
1026 let output = "This is a test response";
1027 let prompt = create_score_prompt(None);
1028
1029 for j in 0..10 {
1030 let context = serde_json::json!({
1031 "input": input,
1032 "response": output,
1033 });
1034 let record = LLMDriftServerRecord {
1035 created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1036 space: SPACE.to_string(),
1037 name: NAME.to_string(),
1038 version: VERSION.to_string(),
1039 prompt: Some(prompt.model_dump_value()),
1040 context,
1041 score: Value::Null,
1042 status: Status::Pending,
1043 id: 0, uid: "test".to_string(),
1045 updated_at: None,
1046 processing_started_at: None,
1047 processing_ended_at: None,
1048 processing_duration: None,
1049 };
1050
1051 let result = PostgresClient::insert_llm_drift_record(&pool, &record)
1052 .await
1053 .unwrap();
1054
1055 assert_eq!(result.rows_affected(), 1);
1056 }
1057
1058 let service_info = ServiceInfo {
1059 space: SPACE.to_string(),
1060 name: NAME.to_string(),
1061 version: VERSION.to_string(),
1062 };
1063
1064 let pagination = PaginationRequest {
1066 limit: 5,
1067 cursor: None, };
1069
1070 let paginated_features = PostgresClient::get_llm_drift_records_pagination(
1071 &pool,
1072 &service_info,
1073 None,
1074 pagination,
1075 )
1076 .await
1077 .unwrap();
1078
1079 assert_eq!(paginated_features.items.len(), 5);
1080 assert!(paginated_features.next_cursor.is_some());
1081
1082 let last_record = paginated_features.items.first().unwrap();
1084
1085 let next_cursor = paginated_features.next_cursor.unwrap();
1087 let pagination = PaginationRequest {
1088 limit: 5,
1089 cursor: Some(next_cursor),
1090 };
1091
1092 let paginated_features = PostgresClient::get_llm_drift_records_pagination(
1093 &pool,
1094 &service_info,
1095 None,
1096 pagination,
1097 )
1098 .await
1099 .unwrap();
1100
1101 assert_eq!(paginated_features.items.len(), 5);
1102 assert!(paginated_features.next_cursor.is_none());
1103
1104 let first_record = paginated_features.items.last().unwrap();
1106
1107 let diff = last_record.id - first_record.id + 1; assert!(diff == 10);
1109 }
1110
1111 #[tokio::test]
1112 async fn test_postgres_llm_metrics_insert_get() {
1113 let pool = db_pool().await;
1114
1115 let timestamp = Utc::now();
1116
1117 for i in 0..2 {
1118 let mut records = Vec::new();
1119 for j in 0..25 {
1120 let record = LLMMetricRecord {
1121 record_uid: format!("uid{i}{j}"),
1122 created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1123 space: SPACE.to_string(),
1124 name: NAME.to_string(),
1125 version: VERSION.to_string(),
1126 metric: format!("metric{i}"),
1127 value: rand::rng().random_range(0..10) as f64,
1128 };
1129 records.push(record);
1130 }
1131 let result = PostgresClient::insert_llm_metric_values_batch(&pool, &records)
1132 .await
1133 .unwrap();
1134 assert_eq!(result.rows_affected(), 25);
1135 }
1136
1137 let metrics = PostgresClient::get_llm_metric_values(
1138 &pool,
1139 &ServiceInfo {
1140 space: SPACE.to_string(),
1141 name: NAME.to_string(),
1142 version: VERSION.to_string(),
1143 },
1144 ×tamp,
1145 &["metric1".to_string()],
1146 )
1147 .await
1148 .unwrap();
1149
1150 assert_eq!(metrics.len(), 1);
1151 let binned_records = PostgresClient::get_binned_llm_metric_values(
1152 &pool,
1153 &DriftRequest {
1154 space: SPACE.to_string(),
1155 name: NAME.to_string(),
1156 version: VERSION.to_string(),
1157 time_interval: TimeInterval::OneHour,
1158 max_data_points: 1000,
1159 drift_type: DriftType::LLM,
1160 ..Default::default()
1161 },
1162 &DatabaseSettings::default().retention_period,
1163 &ObjectStorageSettings::default(),
1164 )
1165 .await
1166 .unwrap();
1167 assert_eq!(binned_records.metrics.len(), 2);
1169 }
1170
1171 #[tokio::test]
1172 async fn test_postgres_tracing() {
1173 let pool = db_pool().await;
1174 let script = std::fs::read_to_string("src/tests/script/populate_trace.sql").unwrap();
1175 sqlx::query(&script).execute(&pool).await.unwrap();
1176 let mut filters = TraceFilters::default();
1177
1178 let first_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1179 .await
1180 .unwrap();
1181
1182 assert_eq!(
1183 first_batch.items.len(),
1184 50,
1185 "First batch should have 50 records"
1186 );
1187
1188 let last_record = first_batch.next_cursor.unwrap();
1190 filters = filters.next_page(&last_record);
1191
1192 let next_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1193 .await
1194 .unwrap();
1195
1196 assert_eq!(
1198 next_batch.items.len(),
1199 50,
1200 "Next batch should have 50 records"
1201 );
1202
1203 let next_first_record = next_batch.items.first().unwrap();
1205 assert!(
1206 next_first_record.created_at <= last_record.created_at,
1207 "Next batch first record timestamp is not less than or equal to last record timestamp"
1208 );
1209
1210 filters = filters.previous_page(&next_batch.previous_cursor.unwrap());
1212 let previous_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1213 .await
1214 .unwrap();
1215 assert_eq!(
1216 previous_batch.items.len(),
1217 50,
1218 "Previous batch should have 50 records"
1219 );
1220
1221 let filtered_record = first_batch
1223 .items
1224 .iter()
1225 .find(|record| record.span_count > Some(5))
1226 .unwrap();
1227
1228 filters.cursor_created_at = None;
1229 filters.cursor_trace_id = None;
1230 filters.space = Some(filtered_record.space.clone());
1231 filters.name = Some(filtered_record.name.clone());
1232 filters.version = Some(filtered_record.version.clone());
1233
1234 let records = PostgresClient::get_traces_paginated(&pool, filters.clone())
1235 .await
1236 .unwrap();
1237
1238 assert!(
1240 !records.items.is_empty(),
1241 "Should return records with specified filters"
1242 );
1243
1244 let spans = PostgresClient::get_trace_spans(&pool, &filtered_record.trace_id)
1246 .await
1247 .unwrap();
1248
1249 assert!(spans.len() == filtered_record.span_count.unwrap() as usize);
1250
1251 let start_time = filtered_record.created_at - chrono::Duration::hours(24);
1252 let end_time = filtered_record.created_at + chrono::Duration::minutes(5);
1253
1254 let trace_metrics = PostgresClient::get_trace_metrics(
1256 &pool,
1257 None,
1258 None,
1259 None,
1260 start_time,
1261 end_time,
1262 "60 minutes",
1263 )
1264 .await
1265 .unwrap();
1266
1267 assert!(trace_metrics.len() >= 10);
1269 }
1270
1271 #[tokio::test]
1272 async fn test_postgres_tracing_insert() {
1273 let pool = db_pool().await;
1274
1275 let mut trace_record = random_trace_record();
1277 let trace_id = trace_record.trace_id.clone();
1278
1279 let root_span = random_span_record(&trace_id, None);
1281 let child_span = random_span_record(&trace_id, Some(&root_span.span_id));
1282
1283 trace_record.root_span_id = root_span.span_id.clone();
1285
1286 let result = PostgresClient::upsert_trace_batch(&pool, &[trace_record.clone()])
1288 .await
1289 .unwrap();
1290
1291 assert_eq!(result.rows_affected(), 1);
1292
1293 let result = PostgresClient::upsert_trace_batch(&pool, &[trace_record.clone()])
1295 .await
1296 .unwrap();
1297
1298 assert_eq!(result.rows_affected(), 1);
1299
1300 let result =
1302 PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
1303 .await
1304 .unwrap();
1305
1306 assert_eq!(result.rows_affected(), 2);
1307
1308 sqlx::query("REFRESH MATERIALIZED VIEW scouter.trace_summary;")
1310 .execute(&pool)
1311 .await
1312 .unwrap();
1313
1314 let inserted_created_at = trace_record.created_at;
1315 let inserted_trace_id = trace_record.trace_id.clone();
1316
1317 let trace_filter = TraceFilters {
1318 cursor_created_at: Some(inserted_created_at + Duration::days(1)),
1319 cursor_trace_id: Some(inserted_trace_id),
1320 start_time: Some(inserted_created_at - Duration::minutes(5)),
1321 end_time: Some(inserted_created_at + Duration::days(1)),
1322 ..TraceFilters::default()
1323 };
1324
1325 let traces = PostgresClient::get_traces_paginated(&pool, trace_filter)
1326 .await
1327 .unwrap();
1328
1329 assert_eq!(traces.items.len(), 1);
1330 let retrieved_trace = &traces.items[0];
1331 assert_eq!(retrieved_trace.span_count.unwrap(), 2);
1333
1334 let baggage = TraceBaggageRecord {
1335 created_at: Utc::now(),
1336 trace_id: trace_record.trace_id.clone(),
1337 scope: "test_scope".to_string(),
1338 key: "user_id".to_string(),
1339 value: "12345".to_string(),
1340 };
1341
1342 let result =
1343 PostgresClient::insert_trace_baggage_batch(&pool, std::slice::from_ref(&baggage))
1344 .await
1345 .unwrap();
1346
1347 assert_eq!(result.rows_affected(), 1);
1348
1349 let retrieved_baggage =
1350 PostgresClient::get_trace_baggage_records(&pool, &trace_record.trace_id)
1351 .await
1352 .unwrap();
1353
1354 assert_eq!(retrieved_baggage.len(), 1);
1355 }
1356
1357 #[tokio::test]
1358 async fn test_postgres_tags() {
1359 let pool = db_pool().await;
1360 let uid = create_uuid7();
1361
1362 let tag1 = TagRecord {
1363 created_at: Utc::now(),
1364 entity_id: uid.clone(),
1365 entity_type: "service".to_string(),
1366 key: "env".to_string(),
1367 value: "production".to_string(),
1368 };
1369
1370 let tag2 = TagRecord {
1371 created_at: Utc::now(),
1372 entity_id: uid.clone(),
1373 entity_type: "service".to_string(),
1374 key: "team".to_string(),
1375 value: "backend".to_string(),
1376 };
1377
1378 let result = PostgresClient::insert_tag_batch(&pool, &[tag1.clone(), tag2.clone()])
1379 .await
1380 .unwrap();
1381
1382 assert_eq!(result.rows_affected(), 2);
1383
1384 let tags = PostgresClient::get_tags(&pool, "service", &uid)
1385 .await
1386 .unwrap();
1387
1388 assert_eq!(tags.len(), 2);
1389 }
1390}