scouter_sql/sql/
postgres.rs

1use crate::sql::error::SqlError;
2use crate::sql::traits::{
3    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, LLMDriftSqlLogic, ObservabilitySqlLogic,
4    ProfileSqlLogic, PsiSqlLogic, SpcSqlLogic, TagSqlLogic, TraceSqlLogic, UserSqlLogic,
5};
6use scouter_settings::DatabaseSettings;
7use scouter_types::{RecordType, ServerRecords, TagRecord, ToDriftRecords, TraceServerRecord};
8use sqlx::ConnectOptions;
9use sqlx::{postgres::PgConnectOptions, Pool, Postgres};
10use std::result::Result::Ok;
11use tokio::try_join;
12use tracing::{debug, error, info, instrument};
13
/// Zero-sized handle for the Postgres data-access layer.
///
/// The type holds no state. Its SQL operations are provided by the `*SqlLogic` trait
/// impls and are called as associated functions (`PostgresClient::some_query(pool, ..)`)
/// that receive the connection pool explicitly.
// NOTE: `dead_code` is allowed because the struct appears to be used only as a path for
// associated functions rather than being constructed.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct PostgresClient {}
17
18impl SpcSqlLogic for PostgresClient {}
19impl CustomMetricSqlLogic for PostgresClient {}
20impl PsiSqlLogic for PostgresClient {}
21impl LLMDriftSqlLogic for PostgresClient {}
22impl UserSqlLogic for PostgresClient {}
23impl ProfileSqlLogic for PostgresClient {}
24impl ObservabilitySqlLogic for PostgresClient {}
25impl AlertSqlLogic for PostgresClient {}
26impl ArchiveSqlLogic for PostgresClient {}
27impl TraceSqlLogic for PostgresClient {}
28impl TagSqlLogic for PostgresClient {}
29
30impl PostgresClient {
31    /// Setup the application with the given database pool.
32    ///
33    /// # Returns
34    ///
35    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
36    #[instrument(skip(database_settings))]
37    pub async fn create_db_pool(
38        database_settings: &DatabaseSettings,
39    ) -> Result<Pool<Postgres>, SqlError> {
40        let mut opts: PgConnectOptions = database_settings.connection_uri.parse()?;
41
42        // Sqlx logs a lot of debug information by default, which can be overwhelming.
43
44        opts = opts.log_statements(tracing::log::LevelFilter::Off);
45
46        let pool = match sqlx::postgres::PgPoolOptions::new()
47            .max_connections(database_settings.max_connections)
48            .connect_with(opts)
49            .await
50        {
51            Ok(pool) => {
52                info!("✅ Successfully connected to database");
53                pool
54            }
55            Err(err) => {
56                error!("🚨 Failed to connect to database {:?}", err);
57                std::process::exit(1);
58            }
59        };
60
61        // Run migrations
62        if let Err(err) = Self::run_migrations(&pool).await {
63            error!("🚨 Failed to run migrations {:?}", err);
64            std::process::exit(1);
65        }
66
67        Ok(pool)
68    }
69
70    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
71        info!("Running migrations");
72        sqlx::migrate!("src/migrations")
73            .run(pool)
74            .await
75            .map_err(SqlError::MigrateError)?;
76
77        debug!("Migrations complete");
78
79        Ok(())
80    }
81}
82
83pub struct MessageHandler {}
84
85impl MessageHandler {
86    const DEFAULT_BATCH_SIZE: usize = 500;
87    #[instrument(skip_all)]
88    pub async fn insert_server_records(
89        pool: &Pool<Postgres>,
90        records: &ServerRecords,
91    ) -> Result<(), SqlError> {
92        debug!("Inserting server records: {:?}", records.record_type()?);
93
94        match records.record_type()? {
95            RecordType::Spc => {
96                let spc_records = records.to_spc_drift_records()?;
97                debug!("SPC record count: {}", spc_records.len());
98
99                for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
100                    PostgresClient::insert_spc_drift_records_batch(pool, chunk)
101                        .await
102                        .map_err(|e| {
103                            error!("Failed to insert SPC drift records batch: {:?}", e);
104                            e
105                        })?;
106                }
107            }
108
109            RecordType::Psi => {
110                let psi_records = records.to_psi_drift_records()?;
111                debug!("PSI record count: {}", psi_records.len());
112
113                for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
114                    PostgresClient::insert_bin_counts_batch(pool, chunk)
115                        .await
116                        .map_err(|e| {
117                            error!("Failed to insert PSI drift records batch: {:?}", e);
118                            e
119                        })?;
120                }
121            }
122            RecordType::Custom => {
123                let custom_records = records.to_custom_metric_drift_records()?;
124                debug!("Custom record count: {}", custom_records.len());
125
126                for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
127                    PostgresClient::insert_custom_metric_values_batch(pool, chunk)
128                        .await
129                        .map_err(|e| {
130                            error!("Failed to insert custom metric records batch: {:?}", e);
131                            e
132                        })?;
133                }
134            }
135
136            RecordType::LLMDrift => {
137                debug!("LLM Drift record count: {:?}", records.len());
138                let records = records.to_llm_drift_records()?;
139                for record in records.iter() {
140                    let _ = PostgresClient::insert_llm_drift_record(pool, record)
141                        .await
142                        .map_err(|e| {
143                            error!("Failed to insert LLM drift record: {:?}", e);
144                        });
145                }
146            }
147
148            RecordType::LLMMetric => {
149                debug!("LLM Metric record count: {:?}", records.len());
150                let llm_metric_records = records.to_llm_metric_records()?;
151
152                for chunk in llm_metric_records.chunks(Self::DEFAULT_BATCH_SIZE) {
153                    PostgresClient::insert_llm_metric_values_batch(pool, chunk)
154                        .await
155                        .map_err(|e| {
156                            error!("Failed to insert LLM metric records batch: {:?}", e);
157                            e
158                        })?;
159                }
160            }
161
162            _ => {
163                error!(
164                    "Unsupported record type for batch insert: {:?}",
165                    records.record_type()?
166                );
167                return Err(SqlError::UnsupportedBatchTypeError);
168            }
169        }
170
171        Ok(())
172    }
173
174    pub async fn insert_trace_server_record(
175        pool: &Pool<Postgres>,
176        records: &TraceServerRecord,
177    ) -> Result<(), SqlError> {
178        let (trace_batch, span_batch, baggage_batch) = records.to_records()?;
179
180        let all_tags: Vec<TagRecord> = trace_batch
181            .iter()
182            .flat_map(|trace| {
183                trace.tags.iter().map(|tag| TagRecord {
184                    created_at: trace.created_at,
185                    entity_type: "trace".to_string(),
186                    entity_id: trace.trace_id.clone(),
187                    key: tag.key.clone(),
188                    value: tag.value.clone(),
189                })
190            })
191            .collect();
192
193        let (trace_result, span_result, baggage_result, tag_result) = try_join!(
194            PostgresClient::upsert_trace_batch(pool, &trace_batch),
195            PostgresClient::insert_span_batch(pool, &span_batch),
196            PostgresClient::insert_trace_baggage_batch(pool, &baggage_batch),
197            async {
198                if !all_tags.is_empty() {
199                    PostgresClient::insert_tag_batch(pool, &all_tags).await
200                } else {
201                    Ok(sqlx::postgres::PgQueryResult::default())
202                }
203            }
204        )?;
205
206        debug!(
207            trace_rows = trace_result.rows_affected(),
208            span_rows = span_result.rows_affected(),
209            baggage_rows = baggage_result.rows_affected(),
210            tag_rows = tag_result.rows_affected(),
211            total_traces = trace_batch.len(),
212            total_spans = span_batch.len(),
213            total_baggage = baggage_batch.len(),
214            total_tags = all_tags.len(),
215            "Successfully inserted trace server records"
216        );
217        Ok(())
218    }
219
220    pub async fn insert_tag_record(
221        pool: &Pool<Postgres>,
222        record: &TagRecord,
223    ) -> Result<(), SqlError> {
224        let result = PostgresClient::insert_tag_batch(pool, std::slice::from_ref(record)).await?;
225
226        debug!(
227            rows_affected = result.rows_affected(),
228            entity_type = record.entity_type.as_str(),
229            entity_id = record.entity_id.as_str(),
230            key = record.key.as_str(),
231            "Successfully inserted tag record"
232        );
233
234        Ok(())
235    }
236}
237
/// Runs database integration tests
239/// Note - binned queries targeting custom intervals with long-term and short-term data are
240/// done in the scouter-server integration tests
241#[cfg(test)]
242mod tests {
243
244    use super::*;
245    use crate::sql::schema::User;
246    use chrono::{Duration, Utc};
247    use potato_head::{create_score_prompt, create_uuid7};
248    use rand::Rng;
249    use scouter_semver::VersionType;
250    use scouter_settings::ObjectStorageSettings;
251    use scouter_types::llm::PaginationRequest;
252    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
253    use scouter_types::spc::SpcDriftProfile;
254    use scouter_types::sql::TraceFilters;
255    use scouter_types::*;
256    use serde_json::Value;
257    use sqlx::postgres::PgQueryResult;
258    use std::collections::BTreeMap;
259
260    const SPACE: &str = "space";
261    const NAME: &str = "name";
262    const VERSION: &str = "1.0.0";
263    const SCOPE: &str = "scope";
264
265    fn random_trace_record() -> TraceRecord {
266        let mut rng = rand::rng();
267        let random_num = rng.random_range(0..1000);
268        let trace_id: String = (0..32)
269            .map(|_| format!("{:x}", rng.random_range(0..16)))
270            .collect();
271        let span_id: String = (0..16)
272            .map(|_| format!("{:x}", rng.random_range(0..16)))
273            .collect();
274        let created_at = Utc::now() + chrono::Duration::milliseconds(random_num);
275
276        TraceRecord {
277            trace_id: trace_id.clone(),
278            created_at,
279            space: SPACE.to_string(),
280            name: NAME.to_string(),
281            version: VERSION.to_string(),
282            scope: SCOPE.to_string(),
283            trace_state: "running".to_string(),
284            start_time: created_at,
285            end_time: created_at + chrono::Duration::milliseconds(150),
286            duration_ms: 150,
287            status_code: 0,
288            span_count: 1,
289            status_message: "OK".to_string(),
290            root_span_id: span_id.clone(),
291            tags: vec![],
292        }
293    }
294
295    fn random_span_record(trace_id: &str, parent_span_id: Option<&str>) -> TraceSpanRecord {
296        let mut rng = rand::rng();
297        let span_id: String = (0..16)
298            .map(|_| format!("{:x}", rng.random_range(0..16)))
299            .collect();
300
301        let random_offset_ms = rng.random_range(0..1000);
302        let duration_ms_val = rng.random_range(50..500);
303
304        let created_at = Utc::now() + chrono::Duration::milliseconds(random_offset_ms);
305        let start_time = created_at;
306        let end_time = start_time + chrono::Duration::milliseconds(duration_ms_val);
307
308        // --- Status and Kind ---
309        let status_code = if rng.random_bool(0.95) { 0 } else { 2 };
310        let span_kind_options = ["SERVER", "CLIENT", "INTERNAL", "PRODUCER", "CONSUMER"];
311        let span_kind = span_kind_options[rng.random_range(0..span_kind_options.len())].to_string();
312
313        TraceSpanRecord {
314            created_at,
315            span_id,
316            trace_id: trace_id.to_string(),
317            parent_span_id: parent_span_id.map(|s| s.to_string()),
318            space: SPACE.to_string(),
319            name: NAME.to_string(),
320            version: VERSION.to_string(),
321            scope: SCOPE.to_string(),
322            span_name: format!("{}_{}", "random_operation", rng.random_range(0..10)),
323            span_kind,
324            start_time,
325            end_time,
326            duration_ms: duration_ms_val,
327            status_code,
328            status_message: if status_code == 2 {
329                "Internal Server Error".to_string()
330            } else {
331                "OK".to_string()
332            },
333            attributes: vec![Attribute::default()],
334            events: vec![],
335            links: vec![],
336            label: None,
337            input: Value::default(),
338            output: Value::default(),
339        }
340    }
341
342    pub async fn cleanup(pool: &Pool<Postgres>) {
343        sqlx::raw_sql(
344            r#"
345            DELETE
346            FROM scouter.spc_drift;
347
348            DELETE
349            FROM scouter.observability_metric;
350
351            DELETE
352            FROM scouter.custom_drift;
353
354            DELETE
355            FROM scouter.drift_alert;
356
357            DELETE
358            FROM scouter.drift_profile;
359
360            DELETE
361            FROM scouter.psi_drift;
362
363            DELETE
364            FROM scouter.user;
365
366            DELETE
367            FROM scouter.llm_drift_record;
368
369            DELETE
370            FROM scouter.llm_drift;
371
372            DELETE
373            FROM scouter.spans;
374
375            DELETE
376            FROM scouter.trace_baggage;
377
378            DELETE
379            FROM scouter.traces;
380
381            DELETE
382            FROM scouter.tags;
383            "#,
384        )
385        .fetch_all(pool)
386        .await
387        .unwrap();
388    }
389
390    pub async fn db_pool() -> Pool<Postgres> {
391        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
392            .await
393            .unwrap();
394
395        cleanup(&pool).await;
396
397        pool
398    }
399
400    pub async fn insert_profile_to_db(
401        pool: &Pool<Postgres>,
402        profile: &DriftProfile,
403        active: bool,
404        deactivate_others: bool,
405    ) -> PgQueryResult {
406        let base_args = profile.get_base_args();
407        let version = PostgresClient::get_next_profile_version(
408            pool,
409            &base_args,
410            VersionType::Minor,
411            None,
412            None,
413        )
414        .await
415        .unwrap();
416
417        let result = PostgresClient::insert_drift_profile(
418            pool,
419            profile,
420            &base_args,
421            &version,
422            &active,
423            &deactivate_others,
424        )
425        .await
426        .unwrap();
427
428        result
429    }
430
431    #[tokio::test]
432    async fn test_postgres() {
433        let _pool = db_pool().await;
434    }
435
436    #[tokio::test]
437    async fn test_postgres_drift_alert() {
438        let pool = db_pool().await;
439
440        let timestamp = Utc::now();
441
442        for _ in 0..10 {
443            let task_info = DriftTaskInfo {
444                space: SPACE.to_string(),
445                name: NAME.to_string(),
446                version: VERSION.to_string(),
447                uid: "test".to_string(),
448                drift_type: DriftType::Spc,
449            };
450
451            let alert = (0..10)
452                .map(|i| (i.to_string(), i.to_string()))
453                .collect::<BTreeMap<String, String>>();
454
455            let result = PostgresClient::insert_drift_alert(
456                &pool,
457                &task_info,
458                "test",
459                &alert,
460                &DriftType::Spc,
461            )
462            .await
463            .unwrap();
464
465            assert_eq!(result.rows_affected(), 1);
466        }
467
468        // get alerts
469        let alert_request = DriftAlertRequest {
470            space: SPACE.to_string(),
471            name: NAME.to_string(),
472            version: VERSION.to_string(),
473            active: Some(true),
474            limit: None,
475            limit_datetime: None,
476        };
477
478        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
479            .await
480            .unwrap();
481        assert!(alerts.len() > 5);
482
483        // get alerts limit 1
484        let alert_request = DriftAlertRequest {
485            space: SPACE.to_string(),
486            name: NAME.to_string(),
487            version: VERSION.to_string(),
488            active: Some(true),
489            limit: Some(1),
490            limit_datetime: None,
491        };
492
493        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
494            .await
495            .unwrap();
496        assert_eq!(alerts.len(), 1);
497
498        // get alerts limit timestamp
499        let alert_request = DriftAlertRequest {
500            space: SPACE.to_string(),
501            name: NAME.to_string(),
502            version: VERSION.to_string(),
503            active: Some(true),
504            limit: None,
505            limit_datetime: Some(timestamp),
506        };
507
508        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
509            .await
510            .unwrap();
511        assert!(alerts.len() > 5);
512    }
513
514    #[tokio::test]
515    async fn test_postgres_spc_drift_record() {
516        let pool = db_pool().await;
517
518        let record1 = SpcServerRecord {
519            created_at: Utc::now(),
520            space: SPACE.to_string(),
521            name: NAME.to_string(),
522            version: VERSION.to_string(),
523            feature: "test".to_string(),
524            value: 1.0,
525        };
526
527        let record2 = SpcServerRecord {
528            created_at: Utc::now(),
529            space: SPACE.to_string(),
530            name: NAME.to_string(),
531            version: VERSION.to_string(),
532            feature: "test2".to_string(),
533            value: 2.0,
534        };
535
536        let result = PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2])
537            .await
538            .unwrap();
539
540        assert_eq!(result.rows_affected(), 2);
541    }
542
543    #[tokio::test]
544    async fn test_postgres_bin_count() {
545        let pool = db_pool().await;
546
547        let record1 = PsiServerRecord {
548            created_at: Utc::now(),
549            space: SPACE.to_string(),
550            name: NAME.to_string(),
551            version: VERSION.to_string(),
552            feature: "test".to_string(),
553            bin_id: 1,
554            bin_count: 1,
555        };
556
557        let record2 = PsiServerRecord {
558            created_at: Utc::now(),
559            space: SPACE.to_string(),
560            name: NAME.to_string(),
561            version: VERSION.to_string(),
562            feature: "test2".to_string(),
563            bin_id: 2,
564            bin_count: 2,
565        };
566
567        let result = PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2])
568            .await
569            .unwrap();
570
571        assert_eq!(result.rows_affected(), 2);
572    }
573
574    #[tokio::test]
575    async fn test_postgres_observability_record() {
576        let pool = db_pool().await;
577
578        let record = ObservabilityMetrics::default();
579
580        let result = PostgresClient::insert_observability_record(&pool, &record)
581            .await
582            .unwrap();
583
584        assert_eq!(result.rows_affected(), 1);
585    }
586
587    #[tokio::test]
588    async fn test_postgres_crud_drift_profile() {
589        let pool = db_pool().await;
590
591        let mut spc_profile = SpcDriftProfile::default();
592        let profile = DriftProfile::Spc(spc_profile.clone());
593
594        let result = insert_profile_to_db(&pool, &profile, false, false).await;
595        assert_eq!(result.rows_affected(), 1);
596
597        spc_profile.scouter_version = "test".to_string();
598
599        let result =
600            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
601                .await
602                .unwrap();
603
604        assert_eq!(result.rows_affected(), 1);
605
606        let profile = PostgresClient::get_drift_profile(
607            &pool,
608            &GetProfileRequest {
609                name: spc_profile.config.name.clone(),
610                space: spc_profile.config.space.clone(),
611                version: spc_profile.config.version.clone(),
612                drift_type: DriftType::Spc,
613            },
614        )
615        .await
616        .unwrap();
617
618        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();
619
620        assert_eq!(deserialized, spc_profile);
621
622        PostgresClient::update_drift_profile_status(
623            &pool,
624            &ProfileStatusRequest {
625                name: spc_profile.config.name.clone(),
626                space: spc_profile.config.space.clone(),
627                version: spc_profile.config.version.clone(),
628                active: false,
629                drift_type: Some(DriftType::Spc),
630                deactivate_others: false,
631            },
632        )
633        .await
634        .unwrap();
635    }
636
637    #[tokio::test]
638    async fn test_postgres_get_features() {
639        let pool = db_pool().await;
640
641        let timestamp = Utc::now();
642
643        for _ in 0..10 {
644            let mut records = Vec::new();
645            for j in 0..10 {
646                let record = SpcServerRecord {
647                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
648                    space: SPACE.to_string(),
649                    name: NAME.to_string(),
650                    version: VERSION.to_string(),
651                    feature: format!("test{j}"),
652                    value: j as f64,
653                };
654
655                records.push(record);
656            }
657
658            let result = PostgresClient::insert_spc_drift_records_batch(&pool, &records)
659                .await
660                .unwrap();
661            assert_eq!(result.rows_affected(), records.len() as u64);
662        }
663
664        let service_info = ServiceInfo {
665            space: SPACE.to_string(),
666            name: NAME.to_string(),
667            version: VERSION.to_string(),
668        };
669
670        let features = PostgresClient::get_spc_features(&pool, &service_info)
671            .await
672            .unwrap();
673        assert_eq!(features.len(), 10);
674
675        let records =
676            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
677                .await
678                .unwrap();
679
680        assert_eq!(records.features.len(), 10);
681
682        let binned_records = PostgresClient::get_binned_spc_drift_records(
683            &pool,
684            &DriftRequest {
685                space: SPACE.to_string(),
686                name: NAME.to_string(),
687                version: VERSION.to_string(),
688                time_interval: TimeInterval::FiveMinutes,
689                max_data_points: 10,
690                drift_type: DriftType::Spc,
691                ..Default::default()
692            },
693            &DatabaseSettings::default().retention_period,
694            &ObjectStorageSettings::default(),
695        )
696        .await
697        .unwrap();
698
699        assert_eq!(binned_records.features.len(), 10);
700    }
701
702    #[tokio::test]
703    async fn test_postgres_bin_proportions() {
704        let pool = db_pool().await;
705
706        let timestamp = Utc::now();
707
708        let num_features = 3;
709        let num_bins = 5;
710
711        let features = (0..=num_features)
712            .map(|feature| {
713                let bins = (0..=num_bins)
714                    .map(|bind_id| Bin {
715                        id: bind_id,
716                        lower_limit: None,
717                        upper_limit: None,
718                        proportion: 0.0,
719                    })
720                    .collect();
721                let feature_name = format!("feature{feature}");
722                let feature_profile = PsiFeatureDriftProfile {
723                    id: feature_name.clone(),
724                    bins,
725                    timestamp,
726                    bin_type: BinType::Numeric,
727                };
728                (feature_name, feature_profile)
729            })
730            .collect();
731
732        let profile = &DriftProfile::Psi(psi::PsiDriftProfile::new(
733            features,
734            PsiDriftConfig {
735                space: SPACE.to_string(),
736                name: NAME.to_string(),
737                version: VERSION.to_string(),
738                ..Default::default()
739            },
740        ));
741        let _ = insert_profile_to_db(&pool, profile, false, false).await;
742
743        for feature in 0..num_features {
744            for bin in 0..=num_bins {
745                let mut records = Vec::new();
746                for j in 0..=100 {
747                    let record = PsiServerRecord {
748                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
749                        space: SPACE.to_string(),
750                        name: NAME.to_string(),
751                        version: VERSION.to_string(),
752                        feature: format!("feature{feature}"),
753                        bin_id: bin,
754                        bin_count: rand::rng().random_range(0..10),
755                    };
756
757                    records.push(record);
758                }
759                PostgresClient::insert_bin_counts_batch(&pool, &records)
760                    .await
761                    .unwrap();
762            }
763        }
764
765        let binned_records = PostgresClient::get_feature_distributions(
766            &pool,
767            &ServiceInfo {
768                space: SPACE.to_string(),
769                name: NAME.to_string(),
770                version: VERSION.to_string(),
771            },
772            &timestamp,
773            &["feature0".to_string()],
774        )
775        .await
776        .unwrap();
777
778        // assert binned_records.features["test"]["decile_1"] is around .5
779        let bin_proportion = binned_records
780            .distributions
781            .get("feature0")
782            .unwrap()
783            .bins
784            .get(&1)
785            .unwrap();
786
787        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);
788
789        let binned_records = PostgresClient::get_binned_psi_drift_records(
790            &pool,
791            &DriftRequest {
792                space: SPACE.to_string(),
793                name: NAME.to_string(),
794                version: VERSION.to_string(),
795                time_interval: TimeInterval::OneHour,
796                max_data_points: 1000,
797                drift_type: DriftType::Psi,
798                ..Default::default()
799            },
800            &DatabaseSettings::default().retention_period,
801            &ObjectStorageSettings::default(),
802        )
803        .await
804        .unwrap();
805        //
806        assert_eq!(binned_records.len(), 3);
807    }
808
809    #[tokio::test]
810    async fn test_postgres_cru_custom_metric() {
811        let pool = db_pool().await;
812
813        let timestamp = Utc::now();
814
815        for i in 0..2 {
816            let mut records = Vec::new();
817            for j in 0..25 {
818                let record = CustomMetricServerRecord {
819                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
820                    space: SPACE.to_string(),
821                    name: NAME.to_string(),
822                    version: VERSION.to_string(),
823                    metric: format!("metric{i}"),
824                    value: rand::rng().random_range(0..10) as f64,
825                };
826                records.push(record);
827            }
828            let result = PostgresClient::insert_custom_metric_values_batch(&pool, &records)
829                .await
830                .unwrap();
831            assert_eq!(result.rows_affected(), 25);
832        }
833
834        // insert random record to test has statistics funcs handle single record
835        let record = CustomMetricServerRecord {
836            created_at: Utc::now(),
837            space: SPACE.to_string(),
838            name: NAME.to_string(),
839            version: VERSION.to_string(),
840            metric: "metric3".to_string(),
841            value: rand::rng().random_range(0..10) as f64,
842        };
843
844        let result = PostgresClient::insert_custom_metric_values_batch(&pool, &[record])
845            .await
846            .unwrap();
847        assert_eq!(result.rows_affected(), 1);
848
849        let metrics = PostgresClient::get_custom_metric_values(
850            &pool,
851            &ServiceInfo {
852                space: SPACE.to_string(),
853                name: NAME.to_string(),
854                version: VERSION.to_string(),
855            },
856            &timestamp,
857            &["metric1".to_string()],
858        )
859        .await
860        .unwrap();
861
862        assert_eq!(metrics.len(), 1);
863
864        let binned_records = PostgresClient::get_binned_custom_drift_records(
865            &pool,
866            &DriftRequest {
867                space: SPACE.to_string(),
868                name: NAME.to_string(),
869                version: VERSION.to_string(),
870                time_interval: TimeInterval::OneHour,
871                max_data_points: 1000,
872                drift_type: DriftType::Custom,
873                ..Default::default()
874            },
875            &DatabaseSettings::default().retention_period,
876            &ObjectStorageSettings::default(),
877        )
878        .await
879        .unwrap();
880        //
881        assert_eq!(binned_records.metrics.len(), 3);
882    }
883
884    #[tokio::test]
885    async fn test_postgres_user() {
886        let pool = db_pool().await;
887        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];
888
889        // Create
890        let user = User::new(
891            "user".to_string(),
892            "pass".to_string(),
893            "email".to_string(),
894            recovery_codes,
895            None,
896            None,
897            None,
898            None,
899        );
900        PostgresClient::insert_user(&pool, &user).await.unwrap();
901
902        // Read
903        let mut user = PostgresClient::get_user(&pool, "user")
904            .await
905            .unwrap()
906            .unwrap();
907
908        assert_eq!(user.username, "user");
909        assert_eq!(user.group_permissions, vec!["user"]);
910        assert_eq!(user.email, "email");
911
912        // update user
913        user.active = false;
914        user.refresh_token = Some("token".to_string());
915
916        // Update
917        PostgresClient::update_user(&pool, &user).await.unwrap();
918        let user = PostgresClient::get_user(&pool, "user")
919            .await
920            .unwrap()
921            .unwrap();
922        assert!(!user.active);
923        assert_eq!(user.refresh_token.unwrap(), "token");
924
925        // get users
926        let users = PostgresClient::get_users(&pool).await.unwrap();
927        assert_eq!(users.len(), 1);
928
929        // get last admin
930        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
931        assert!(!is_last_admin);
932
933        // delete
934        PostgresClient::delete_user(&pool, "user").await.unwrap();
935    }
936
937    #[tokio::test]
938    async fn test_postgres_llm_drift_record_insert_get() {
939        let pool = db_pool().await;
940
941        let input = "This is a test input";
942        let output = "This is a test response";
943        let prompt = create_score_prompt(None);
944
945        for j in 0..10 {
946            let context = serde_json::json!({
947                "input": input,
948                "response": output,
949            });
950            let record = LLMDriftServerRecord {
951                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
952                space: SPACE.to_string(),
953                name: NAME.to_string(),
954                version: VERSION.to_string(),
955                prompt: Some(prompt.model_dump_value()),
956                context,
957                status: Status::Pending,
958                id: 0, // This will be set by the database
959                uid: "test".to_string(),
960                updated_at: None,
961                score: Value::Null,
962                processing_started_at: None,
963                processing_ended_at: None,
964                processing_duration: None,
965            };
966
967            let result = PostgresClient::insert_llm_drift_record(&pool, &record)
968                .await
969                .unwrap();
970
971            assert_eq!(result.rows_affected(), 1);
972        }
973
974        let service_info = ServiceInfo {
975            space: SPACE.to_string(),
976            name: NAME.to_string(),
977            version: VERSION.to_string(),
978        };
979
980        let features = PostgresClient::get_llm_drift_records(&pool, &service_info, None, None)
981            .await
982            .unwrap();
983        assert_eq!(features.len(), 10);
984
985        // get pending task
986        let pending_tasks = PostgresClient::get_pending_llm_drift_record(&pool)
987            .await
988            .unwrap();
989
990        // assert not empty
991        assert!(pending_tasks.is_some());
992
993        // get pending task with space, name, version
994        let task_input = &pending_tasks.as_ref().unwrap().context["input"];
995        assert_eq!(*task_input, "This is a test input".to_string());
996
997        // update pending task
998        PostgresClient::update_llm_drift_record_status(
999            &pool,
1000            &pending_tasks.unwrap(),
1001            Status::Processed,
1002            Some(1),
1003        )
1004        .await
1005        .unwrap();
1006
1007        // query processed tasks
1008        let processed_tasks = PostgresClient::get_llm_drift_records(
1009            &pool,
1010            &service_info,
1011            None,
1012            Some(Status::Processed),
1013        )
1014        .await
1015        .unwrap();
1016
1017        // assert not empty
1018        assert_eq!(processed_tasks.len(), 1);
1019    }
1020
1021    #[tokio::test]
1022    async fn test_postgres_llm_drift_record_pagination() {
1023        let pool = db_pool().await;
1024
1025        let input = "This is a test input";
1026        let output = "This is a test response";
1027        let prompt = create_score_prompt(None);
1028
1029        for j in 0..10 {
1030            let context = serde_json::json!({
1031                "input": input,
1032                "response": output,
1033            });
1034            let record = LLMDriftServerRecord {
1035                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1036                space: SPACE.to_string(),
1037                name: NAME.to_string(),
1038                version: VERSION.to_string(),
1039                prompt: Some(prompt.model_dump_value()),
1040                context,
1041                score: Value::Null,
1042                status: Status::Pending,
1043                id: 0, // This will be set by the database
1044                uid: "test".to_string(),
1045                updated_at: None,
1046                processing_started_at: None,
1047                processing_ended_at: None,
1048                processing_duration: None,
1049            };
1050
1051            let result = PostgresClient::insert_llm_drift_record(&pool, &record)
1052                .await
1053                .unwrap();
1054
1055            assert_eq!(result.rows_affected(), 1);
1056        }
1057
1058        let service_info = ServiceInfo {
1059            space: SPACE.to_string(),
1060            name: NAME.to_string(),
1061            version: VERSION.to_string(),
1062        };
1063
1064        // Get paginated records (1st page)
1065        let pagination = PaginationRequest {
1066            limit: 5,
1067            cursor: None, // Start from the beginning
1068        };
1069
1070        let paginated_features = PostgresClient::get_llm_drift_records_pagination(
1071            &pool,
1072            &service_info,
1073            None,
1074            pagination,
1075        )
1076        .await
1077        .unwrap();
1078
1079        assert_eq!(paginated_features.items.len(), 5);
1080        assert!(paginated_features.next_cursor.is_some());
1081
1082        // get id of the most recent record in the first page
1083        let last_record = paginated_features.items.first().unwrap();
1084
1085        // Get paginated records (2nd page)
1086        let next_cursor = paginated_features.next_cursor.unwrap();
1087        let pagination = PaginationRequest {
1088            limit: 5,
1089            cursor: Some(next_cursor),
1090        };
1091
1092        let paginated_features = PostgresClient::get_llm_drift_records_pagination(
1093            &pool,
1094            &service_info,
1095            None,
1096            pagination,
1097        )
1098        .await
1099        .unwrap();
1100
1101        assert_eq!(paginated_features.items.len(), 5);
1102        assert!(paginated_features.next_cursor.is_none());
1103
1104        // get last record of the second page
1105        let first_record = paginated_features.items.last().unwrap();
1106
1107        let diff = last_record.id - first_record.id + 1; // +1 because IDs are inclusive
1108        assert!(diff == 10);
1109    }
1110
1111    #[tokio::test]
1112    async fn test_postgres_llm_metrics_insert_get() {
1113        let pool = db_pool().await;
1114
1115        let timestamp = Utc::now();
1116
1117        for i in 0..2 {
1118            let mut records = Vec::new();
1119            for j in 0..25 {
1120                let record = LLMMetricRecord {
1121                    record_uid: format!("uid{i}{j}"),
1122                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1123                    space: SPACE.to_string(),
1124                    name: NAME.to_string(),
1125                    version: VERSION.to_string(),
1126                    metric: format!("metric{i}"),
1127                    value: rand::rng().random_range(0..10) as f64,
1128                };
1129                records.push(record);
1130            }
1131            let result = PostgresClient::insert_llm_metric_values_batch(&pool, &records)
1132                .await
1133                .unwrap();
1134            assert_eq!(result.rows_affected(), 25);
1135        }
1136
1137        let metrics = PostgresClient::get_llm_metric_values(
1138            &pool,
1139            &ServiceInfo {
1140                space: SPACE.to_string(),
1141                name: NAME.to_string(),
1142                version: VERSION.to_string(),
1143            },
1144            &timestamp,
1145            &["metric1".to_string()],
1146        )
1147        .await
1148        .unwrap();
1149
1150        assert_eq!(metrics.len(), 1);
1151        let binned_records = PostgresClient::get_binned_llm_metric_values(
1152            &pool,
1153            &DriftRequest {
1154                space: SPACE.to_string(),
1155                name: NAME.to_string(),
1156                version: VERSION.to_string(),
1157                time_interval: TimeInterval::OneHour,
1158                max_data_points: 1000,
1159                drift_type: DriftType::LLM,
1160                ..Default::default()
1161            },
1162            &DatabaseSettings::default().retention_period,
1163            &ObjectStorageSettings::default(),
1164        )
1165        .await
1166        .unwrap();
1167        //
1168        assert_eq!(binned_records.metrics.len(), 2);
1169    }
1170
1171    #[tokio::test]
1172    async fn test_postgres_tracing() {
1173        let pool = db_pool().await;
1174        let script = std::fs::read_to_string("src/tests/script/populate_trace.sql").unwrap();
1175        sqlx::query(&script).execute(&pool).await.unwrap();
1176        let mut filters = TraceFilters::default();
1177
1178        let first_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1179            .await
1180            .unwrap();
1181
1182        assert_eq!(
1183            first_batch.items.len(),
1184            50,
1185            "First batch should have 50 records"
1186        );
1187
1188        // test pagination (get last record created_at and trace_id)
1189        let last_record = first_batch.next_cursor.unwrap();
1190        filters = filters.next_page(&last_record);
1191
1192        let next_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1193            .await
1194            .unwrap();
1195
1196        // should be another 50 records
1197        assert_eq!(
1198            next_batch.items.len(),
1199            50,
1200            "Next batch should have 50 records"
1201        );
1202
1203        // assert next_batch first record timestamp is <= first_batch last record timestamp
1204        let next_first_record = next_batch.items.first().unwrap();
1205        assert!(
1206            next_first_record.created_at <= last_record.created_at,
1207            "Next batch first record timestamp is not less than or equal to last record timestamp"
1208        );
1209
1210        // test pagination for previous
1211        filters = filters.previous_page(&next_batch.previous_cursor.unwrap());
1212        let previous_batch = PostgresClient::get_traces_paginated(&pool, filters.clone())
1213            .await
1214            .unwrap();
1215        assert_eq!(
1216            previous_batch.items.len(),
1217            50,
1218            "Previous batch should have 50 records"
1219        );
1220
1221        // Filter records to find item with >5 spans
1222        let filtered_record = first_batch
1223            .items
1224            .iter()
1225            .find(|record| record.span_count > Some(5))
1226            .unwrap();
1227
1228        filters.cursor_created_at = None;
1229        filters.cursor_trace_id = None;
1230        filters.space = Some(filtered_record.space.clone());
1231        filters.name = Some(filtered_record.name.clone());
1232        filters.version = Some(filtered_record.version.clone());
1233
1234        let records = PostgresClient::get_traces_paginated(&pool, filters.clone())
1235            .await
1236            .unwrap();
1237
1238        // Records are randomly generated, so just assert we get some records back
1239        assert!(
1240            !records.items.is_empty(),
1241            "Should return records with specified filters"
1242        );
1243
1244        // get spans for filtered trace
1245        let spans = PostgresClient::get_trace_spans(&pool, &filtered_record.trace_id)
1246            .await
1247            .unwrap();
1248
1249        assert!(spans.len() == filtered_record.span_count.unwrap() as usize);
1250
1251        let start_time = filtered_record.created_at - chrono::Duration::hours(24);
1252        let end_time = filtered_record.created_at + chrono::Duration::minutes(5);
1253
1254        // make request for trace metrics
1255        let trace_metrics = PostgresClient::get_trace_metrics(
1256            &pool,
1257            None,
1258            None,
1259            None,
1260            start_time,
1261            end_time,
1262            "60 minutes",
1263        )
1264        .await
1265        .unwrap();
1266
1267        // assert we have data points
1268        assert!(trace_metrics.len() >= 10);
1269    }
1270
1271    #[tokio::test]
1272    async fn test_postgres_tracing_insert() {
1273        let pool = db_pool().await;
1274
1275        // create parent trace
1276        let mut trace_record = random_trace_record();
1277        let trace_id = trace_record.trace_id.clone();
1278
1279        // create spans
1280        let root_span = random_span_record(&trace_id, None);
1281        let child_span = random_span_record(&trace_id, Some(&root_span.span_id));
1282
1283        // set root span id in trace record
1284        trace_record.root_span_id = root_span.span_id.clone();
1285
1286        // this should perform an insert
1287        let result = PostgresClient::upsert_trace_batch(&pool, &[trace_record.clone()])
1288            .await
1289            .unwrap();
1290
1291        assert_eq!(result.rows_affected(), 1);
1292
1293        // this should perform an update (mainly just increasing span count)
1294        let result = PostgresClient::upsert_trace_batch(&pool, &[trace_record.clone()])
1295            .await
1296            .unwrap();
1297
1298        assert_eq!(result.rows_affected(), 1);
1299
1300        // insert spans
1301        let result =
1302            PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
1303                .await
1304                .unwrap();
1305
1306        assert_eq!(result.rows_affected(), 2);
1307
1308        // refresh materialized view
1309        sqlx::query("REFRESH MATERIALIZED VIEW scouter.trace_summary;")
1310            .execute(&pool)
1311            .await
1312            .unwrap();
1313
1314        let inserted_created_at = trace_record.created_at;
1315        let inserted_trace_id = trace_record.trace_id.clone();
1316
1317        let trace_filter = TraceFilters {
1318            cursor_created_at: Some(inserted_created_at + Duration::days(1)),
1319            cursor_trace_id: Some(inserted_trace_id),
1320            start_time: Some(inserted_created_at - Duration::minutes(5)),
1321            end_time: Some(inserted_created_at + Duration::days(1)),
1322            ..TraceFilters::default()
1323        };
1324
1325        let traces = PostgresClient::get_traces_paginated(&pool, trace_filter)
1326            .await
1327            .unwrap();
1328
1329        assert_eq!(traces.items.len(), 1);
1330        let retrieved_trace = &traces.items[0];
1331        // assert span count is 2
1332        assert_eq!(retrieved_trace.span_count.unwrap(), 2);
1333
1334        let baggage = TraceBaggageRecord {
1335            created_at: Utc::now(),
1336            trace_id: trace_record.trace_id.clone(),
1337            scope: "test_scope".to_string(),
1338            key: "user_id".to_string(),
1339            value: "12345".to_string(),
1340        };
1341
1342        let result =
1343            PostgresClient::insert_trace_baggage_batch(&pool, std::slice::from_ref(&baggage))
1344                .await
1345                .unwrap();
1346
1347        assert_eq!(result.rows_affected(), 1);
1348
1349        let retrieved_baggage =
1350            PostgresClient::get_trace_baggage_records(&pool, &trace_record.trace_id)
1351                .await
1352                .unwrap();
1353
1354        assert_eq!(retrieved_baggage.len(), 1);
1355    }
1356
1357    #[tokio::test]
1358    async fn test_postgres_tags() {
1359        let pool = db_pool().await;
1360        let uid = create_uuid7();
1361
1362        let tag1 = TagRecord {
1363            created_at: Utc::now(),
1364            entity_id: uid.clone(),
1365            entity_type: "service".to_string(),
1366            key: "env".to_string(),
1367            value: "production".to_string(),
1368        };
1369
1370        let tag2 = TagRecord {
1371            created_at: Utc::now(),
1372            entity_id: uid.clone(),
1373            entity_type: "service".to_string(),
1374            key: "team".to_string(),
1375            value: "backend".to_string(),
1376        };
1377
1378        let result = PostgresClient::insert_tag_batch(&pool, &[tag1.clone(), tag2.clone()])
1379            .await
1380            .unwrap();
1381
1382        assert_eq!(result.rows_affected(), 2);
1383
1384        let tags = PostgresClient::get_tags(&pool, "service", &uid)
1385            .await
1386            .unwrap();
1387
1388        assert_eq!(tags.len(), 2);
1389    }
1390}