Skip to main content

scouter_sql/sql/
postgres.rs

1use crate::sql::cache::entity_cache;
2use crate::sql::cache::init_entity_cache;
3use crate::sql::error::SqlError;
4use crate::sql::traits::{
5    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, GenAIDriftSqlLogic,
6    ObservabilitySqlLogic, ProfileSqlLogic, PsiSqlLogic, SpcSqlLogic, TagSqlLogic, TraceSqlLogic,
7    UserSqlLogic,
8};
9use scouter_settings::DatabaseSettings;
10use scouter_types::{RecordType, ServerRecords, TagRecord, ToDriftRecords, TraceServerRecord};
11use sqlx::ConnectOptions;
12use sqlx::{postgres::PgConnectOptions, Pool, Postgres};
13use std::result::Result::Ok;
14use tokio::try_join;
15use tracing::log::LevelFilter;
16use tracing::{debug, error, info, instrument};
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl GenAIDriftSqlLogic for PostgresClient {}
26impl UserSqlLogic for PostgresClient {}
27impl ProfileSqlLogic for PostgresClient {}
28impl ObservabilitySqlLogic for PostgresClient {}
29impl AlertSqlLogic for PostgresClient {}
30impl ArchiveSqlLogic for PostgresClient {}
31impl TraceSqlLogic for PostgresClient {}
32impl TagSqlLogic for PostgresClient {}
33
34impl PostgresClient {
35    /// Setup the application with the given database pool.
36    ///
37    /// # Returns
38    ///
39    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
40    #[instrument(skip(database_settings))]
41    pub async fn create_db_pool(
42        database_settings: &DatabaseSettings,
43    ) -> Result<Pool<Postgres>, SqlError> {
44        let mut opts: PgConnectOptions = database_settings.connection_uri.parse()?;
45
46        // Sqlx logs a lot of debug information by default, which can be overwhelming.
47        opts = opts.log_statements(LevelFilter::Off);
48
49        let pool = match sqlx::postgres::PgPoolOptions::new()
50            .max_connections(database_settings.max_connections)
51            .connect_with(opts)
52            .await
53        {
54            Ok(pool) => {
55                info!("✅ Successfully connected to database");
56                pool
57            }
58            Err(err) => {
59                error!("🚨 Failed to connect to database {:?}", err);
60                std::process::exit(1);
61            }
62        };
63
64        // setup entity cache
65        init_entity_cache(1000);
66
67        // Run migrations
68        if let Err(err) = Self::run_migrations(&pool).await {
69            error!("🚨 Failed to run migrations {:?}", err);
70            std::process::exit(1);
71        }
72
73        Ok(pool)
74    }
75
76    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
77        info!("Running migrations");
78        sqlx::migrate!("src/migrations")
79            .run(pool)
80            .await
81            .map_err(SqlError::MigrateError)?;
82
83        debug!("Migrations complete");
84
85        Ok(())
86    }
87}
88
89pub struct MessageHandler {}
90
91impl MessageHandler {
92    const DEFAULT_BATCH_SIZE: usize = 500;
93    #[instrument(skip_all)]
94    pub async fn insert_server_records(
95        pool: &Pool<Postgres>,
96        records: ServerRecords,
97    ) -> Result<(), SqlError> {
98        debug!("Inserting server records: {:?}", records.record_type()?);
99
100        let entity_id = entity_cache()
101            .get_entity_id_from_uid(pool, records.uid()?)
102            .await
103            .inspect_err(|e| error!("Failed to get entity ID from UID: {:?}", e))?;
104
105        match records.record_type()? {
106            RecordType::Spc => {
107                let spc_records = records.to_spc_drift_records()?;
108                debug!("SPC record count: {}", spc_records.len());
109
110                for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
111                    PostgresClient::insert_spc_drift_records_batch(pool, chunk, &entity_id)
112                        .await
113                        .map_err(|e| {
114                            error!("Failed to insert SPC drift records batch: {:?}", e);
115                            e
116                        })?;
117                }
118            }
119
120            RecordType::Psi => {
121                let psi_records = records.to_psi_drift_records()?;
122                debug!("PSI record count: {}", psi_records.len());
123
124                for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
125                    PostgresClient::insert_bin_counts_batch(pool, chunk, &entity_id)
126                        .await
127                        .map_err(|e| {
128                            error!("Failed to insert PSI drift records batch: {:?}", e);
129                            e
130                        })?;
131                }
132            }
133            RecordType::Custom => {
134                let custom_records = records.to_custom_metric_drift_records()?;
135                debug!("Custom record count: {}", custom_records.len());
136
137                for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
138                    PostgresClient::insert_custom_metric_values_batch(pool, chunk, &entity_id)
139                        .await
140                        .map_err(|e| {
141                            error!("Failed to insert custom metric records batch: {:?}", e);
142                            e
143                        })?;
144                }
145            }
146
147            RecordType::GenAIEval => {
148                debug!("LLM Drift record count: {:?}", records.len());
149                let records = records.to_genai_eval_records()?;
150                for record in records {
151                    let _ = PostgresClient::insert_genai_eval_record(pool, record, &entity_id)
152                        .await
153                        .map_err(|e| {
154                            error!("Failed to insert GenAI drift record: {:?}", e);
155                        });
156                }
157            }
158
159            RecordType::GenAITask => {
160                debug!("GenAI Task count: {:?}", records.len());
161                let records = records.to_genai_task_records()?;
162                for chunk in records.chunks(Self::DEFAULT_BATCH_SIZE) {
163                    PostgresClient::insert_eval_task_results_batch(pool, chunk, &entity_id)
164                        .await
165                        .map_err(|e| {
166                            error!("Failed to insert GenAI task records batch: {:?}", e);
167                            e
168                        })?;
169                }
170            }
171
172            RecordType::GenAIWorkflow => {
173                debug!("GenAI Workflow count: {:?}", records.len());
174                let records = records.to_genai_workflow_records()?;
175                for record in records {
176                    let _ = PostgresClient::insert_genai_eval_workflow_record(
177                        pool, &record, &entity_id,
178                    )
179                    .await
180                    .map_err(|e| {
181                        error!("Failed to insert GenAI workflow record: {:?}", e);
182                    });
183                }
184            }
185
186            _ => {
187                error!(
188                    "Unsupported record type for batch insert: {:?}",
189                    records.record_type()?
190                );
191                return Err(SqlError::UnsupportedBatchTypeError);
192            }
193        }
194
195        Ok(())
196    }
197
198    pub async fn insert_trace_server_record(
199        pool: &Pool<Postgres>,
200        records: TraceServerRecord,
201    ) -> Result<(), SqlError> {
202        let (span_batch, baggage_batch, tag_records) = records.to_records()?;
203
204        let (span_result, baggage_result, tag_result) = try_join!(
205            PostgresClient::insert_span_batch(pool, &span_batch),
206            PostgresClient::insert_trace_baggage_batch(pool, &baggage_batch),
207            PostgresClient::insert_tag_batch(pool, &tag_records),
208        )?;
209
210        debug!(
211            span_rows = span_result.rows_affected(),
212            baggage_rows = baggage_result.rows_affected(),
213            total_spans = span_batch.len(),
214            total_baggage = baggage_batch.len(),
215            tag_rows = tag_result.rows_affected(),
216            "Successfully inserted trace server records"
217        );
218        Ok(())
219    }
220
221    pub async fn insert_tag_record(
222        pool: &Pool<Postgres>,
223        record: TagRecord,
224    ) -> Result<(), SqlError> {
225        let result = PostgresClient::insert_tag_batch(pool, std::slice::from_ref(&record)).await?;
226
227        debug!(
228            rows_affected = result.rows_affected(),
229            entity_type = record.entity_type.as_str(),
230            entity_id = record.entity_id.as_str(),
231            key = record.key.as_str(),
232            "Successfully inserted tag record"
233        );
234
235        Ok(())
236    }
237}
238
239/// Runs database integratino tests
240/// Note - binned queries targeting custom intervals with long-term and short-term data are
241/// done in the scouter-server integration tests
242#[cfg(test)]
243mod tests {
244
245    use std::collections::HashMap;
246
247    use super::*;
248    use crate::sql::schema::User;
249    use crate::sql::traits::EntitySqlLogic;
250    use chrono::{Duration, Utc};
251    use potato_head::create_uuid7;
252    use rand::Rng;
253    use scouter_semver::VersionType;
254    use scouter_settings::ObjectStorageSettings;
255    use scouter_types::genai::ExecutionPlan;
256    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
257    use scouter_types::spc::SpcDriftProfile;
258    use scouter_types::sql::TraceFilters;
259    use scouter_types::*;
260    use serde_json::Value;
261
262    const SPACE: &str = "space";
263    const NAME: &str = "name";
264    const VERSION: &str = "1.0.0";
265    const SCOPE: &str = "scope";
266    const UID: &str = "test";
267    const ENTITY_ID: i32 = 9999;
268
269    fn random_trace_record() -> TraceRecord {
270        let mut rng = rand::rng();
271        let random_num = rng.random_range(0..1000);
272        let trace_id: String = (0..32)
273            .map(|_| format!("{:x}", rng.random_range(0..16)))
274            .collect();
275        let span_id: String = (0..16)
276            .map(|_| format!("{:x}", rng.random_range(0..16)))
277            .collect();
278        let created_at = Utc::now() + chrono::Duration::milliseconds(random_num);
279
280        TraceRecord {
281            trace_id: trace_id.clone(),
282            created_at,
283            service_name: format!("service_{}", random_num % 10),
284            scope: SCOPE.to_string(),
285            trace_state: "running".to_string(),
286            start_time: created_at,
287            end_time: created_at + chrono::Duration::milliseconds(150),
288            duration_ms: 150,
289            status_code: 0,
290            span_count: 1,
291            status_message: "OK".to_string(),
292            root_span_id: span_id.clone(),
293            tags: vec![],
294            process_attributes: vec![],
295        }
296    }
297
298    fn random_span_record(
299        trace_id: &str,
300        parent_span_id: Option<&str>,
301        service_name: &str,
302    ) -> TraceSpanRecord {
303        let mut rng = rand::rng();
304        let span_id: String = (0..16)
305            .map(|_| format!("{:x}", rng.random_range(0..16)))
306            .collect();
307
308        let random_offset_ms = rng.random_range(0..1000);
309        let duration_ms_val = rng.random_range(50..500);
310
311        let created_at = Utc::now() + chrono::Duration::milliseconds(random_offset_ms);
312        let start_time = created_at;
313        let end_time = start_time + chrono::Duration::milliseconds(duration_ms_val);
314
315        // --- Status and Kind ---
316        let status_code = if rng.random_bool(0.95) { 0 } else { 2 };
317        let span_kind_options = ["SERVER", "CLIENT", "INTERNAL", "PRODUCER", "CONSUMER"];
318        let span_kind = span_kind_options[rng.random_range(0..span_kind_options.len())].to_string();
319
320        TraceSpanRecord {
321            created_at,
322            span_id,
323            trace_id: trace_id.to_string(),
324            parent_span_id: parent_span_id.map(|s| s.to_string()),
325            service_name: service_name.to_string(),
326            scope: SCOPE.to_string(),
327            span_name: format!("{}_{}", "random_operation", rng.random_range(0..10)),
328            span_kind,
329            start_time,
330            end_time,
331            duration_ms: duration_ms_val,
332            status_code,
333            status_message: if status_code == 2 {
334                "Internal Server Error".to_string()
335            } else {
336                "OK".to_string()
337            },
338            attributes: vec![Attribute::default()],
339            events: vec![],
340            links: vec![],
341            label: None,
342            input: Value::default(),
343            output: Value::default(),
344            resource_attributes: vec![],
345        }
346    }
347
348    pub async fn cleanup(pool: &Pool<Postgres>) {
349        sqlx::raw_sql(
350            r#"
351            DELETE
352            FROM scouter.service_entities;
353
354            DELETE
355            FROM scouter.drift_entities;
356
357            DELETE
358            FROM scouter.spc_drift;
359
360            DELETE
361            FROM scouter.observability_metric;
362
363            DELETE
364            FROM scouter.custom_drift;
365
366            DELETE
367            FROM scouter.drift_alert;
368
369            DELETE
370            FROM scouter.drift_profile;
371
372            DELETE
373            FROM scouter.psi_drift;
374
375            DELETE
376            FROM scouter.user;
377
378            DELETE
379            FROM scouter.genai_eval_record;
380
381            DELETE
382            FROM scouter.genai_eval_task;
383
384            DELETE
385            FROM scouter.genai_eval_workflow;
386
387            DELETE
388            FROM scouter.spans;
389
390            DELETE
391            FROM scouter.trace_baggage;
392
393            DELETE
394            FROM scouter.tags;
395            "#,
396        )
397        .fetch_all(pool)
398        .await
399        .unwrap();
400    }
401
402    pub async fn db_pool() -> Pool<Postgres> {
403        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
404            .await
405            .unwrap();
406        cleanup(&pool).await;
407        pool
408    }
409
410    pub async fn insert_profile_to_db(
411        pool: &Pool<Postgres>,
412        profile: &DriftProfile,
413        active: bool,
414        deactivate_others: bool,
415    ) -> String {
416        let base_args = profile.get_base_args();
417        let version_request = VersionRequest {
418            version: None,
419            version_type: VersionType::Minor,
420            pre_tag: None,
421            build_tag: None,
422        };
423        let version = PostgresClient::get_next_profile_version(pool, &base_args, version_request)
424            .await
425            .unwrap();
426
427        let result = PostgresClient::insert_drift_profile(
428            pool,
429            profile,
430            &base_args,
431            &version,
432            &active,
433            &deactivate_others,
434        )
435        .await
436        .unwrap();
437
438        result
439    }
440
441    #[tokio::test]
442    async fn test_postgres_start() {
443        let _pool = db_pool().await;
444    }
445
446    #[tokio::test]
447    async fn test_postgres_drift_alert() {
448        let pool = db_pool().await;
449        let entity_id = 9999;
450
451        // Insert 10 alerts with slight delays to ensure different timestamps
452        for i in 0..10 {
453            let alert = AlertMap::Custom(custom::ComparisonMetricAlert {
454                metric_name: "test".to_string(),
455                baseline_value: 0 as f64,
456                observed_value: i as f64,
457                delta: None,
458                alert_threshold: AlertThreshold::Above,
459            });
460
461            let result = PostgresClient::insert_drift_alert(&pool, &entity_id, &alert)
462                .await
463                .unwrap();
464
465            assert_eq!(result.rows_affected(), 1);
466
467            // Small delay to ensure timestamp ordering
468            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
469        }
470
471        // Test 1: Get first page with default limit (50)
472        let request = DriftAlertPaginationRequest {
473            uid: UID.to_string(),
474            active: Some(true),
475            limit: Some(50),
476            ..Default::default()
477        };
478
479        let response = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
480            .await
481            .unwrap();
482
483        assert_eq!(response.items.len(), 10);
484        assert!(!response.has_next); // No more items
485        assert!(!response.has_previous); // First page
486        assert!(response.next_cursor.is_none());
487        assert!(response.previous_cursor.is_some());
488
489        // Test 2: Paginate with limit of 3 - forward direction
490        let request = DriftAlertPaginationRequest {
491            uid: UID.to_string(),
492            active: Some(true),
493            limit: Some(3),
494            direction: Some("next".to_string()),
495            ..Default::default()
496        };
497
498        let page1 = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
499            .await
500            .unwrap();
501
502        assert_eq!(page1.items.len(), 3);
503        assert!(page1.has_next);
504        assert!(!page1.has_previous);
505        assert!(page1.next_cursor.is_some());
506        assert!(page1.previous_cursor.is_some());
507
508        // Test 3: Get second page using next cursor
509        let next_cursor = page1.next_cursor.unwrap();
510        let request = DriftAlertPaginationRequest {
511            uid: UID.to_string(),
512            active: Some(true),
513            limit: Some(3),
514            cursor_created_at: Some(next_cursor.created_at),
515            cursor_id: Some(next_cursor.id as i32),
516            direction: Some("next".to_string()),
517            start_datetime: None,
518            end_datetime: None,
519        };
520
521        let page2 = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
522            .await
523            .unwrap();
524
525        assert_eq!(page2.items.len(), 3);
526        assert!(page2.has_next); // More items available
527        assert!(page2.has_previous); // Can go back
528        assert!(page2.next_cursor.is_some());
529        assert!(page2.previous_cursor.is_some());
530
531        // Verify no overlap between pages
532        let page1_ids: std::collections::HashSet<_> = page1.items.iter().map(|a| a.id).collect();
533        let page2_ids: std::collections::HashSet<_> = page2.items.iter().map(|a| a.id).collect();
534        assert!(page1_ids.is_disjoint(&page2_ids));
535
536        // Test 4: Navigate backward using previous cursor
537        let prev_cursor = page2.previous_cursor.unwrap();
538        let request = DriftAlertPaginationRequest {
539            uid: UID.to_string(),
540            active: Some(true),
541            limit: Some(3),
542            cursor_created_at: Some(prev_cursor.created_at),
543            cursor_id: Some(prev_cursor.id as i32),
544            direction: Some("previous".to_string()),
545            start_datetime: None,
546            end_datetime: None,
547        };
548
549        let page_back = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
550            .await
551            .unwrap();
552
553        assert_eq!(page_back.items.len(), 3);
554        assert!(page_back.has_next); // Can go forward
555                                     // Should return to page1 items
556        assert_eq!(
557            page_back.items.iter().map(|a| a.id).collect::<Vec<_>>(),
558            page1.items.iter().map(|a| a.id).collect::<Vec<_>>()
559        );
560
561        // Test 5: Filter by active status
562        // Deactivate some alerts first
563        let to_deactivate = &page1.items[0];
564        let update_request = UpdateAlertStatus {
565            id: to_deactivate.id,
566            active: false,
567            space: "test_space".to_string(),
568        };
569
570        PostgresClient::update_drift_alert_status(&pool, &update_request)
571            .await
572            .unwrap();
573
574        // Query only active alerts
575        let request = DriftAlertPaginationRequest {
576            uid: UID.to_string(),
577            active: Some(true),
578            limit: Some(50),
579            ..Default::default()
580        };
581
582        let active_alerts = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
583            .await
584            .unwrap();
585
586        assert_eq!(active_alerts.items.len(), 9); // 10 - 1 deactivated
587        assert!(active_alerts.items.iter().all(|a| a.active));
588
589        // Query only inactive alerts
590        let request = DriftAlertPaginationRequest {
591            uid: UID.to_string(),
592            active: Some(false),
593            limit: Some(50),
594            ..Default::default()
595        };
596
597        let inactive_alerts =
598            PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
599                .await
600                .unwrap();
601
602        assert_eq!(inactive_alerts.items.len(), 1);
603        assert!(!inactive_alerts.items[0].active);
604
605        // Test 6: Query all alerts (active and inactive)
606        let request = DriftAlertPaginationRequest {
607            uid: UID.to_string(),
608            active: None, // No filter
609            limit: Some(50),
610            ..Default::default()
611        };
612
613        let all_alerts = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
614            .await
615            .unwrap();
616
617        assert_eq!(all_alerts.items.len(), 10);
618
619        // Test 7: Empty result set
620        let request = DriftAlertPaginationRequest {
621            uid: UID.to_string(),
622            active: Some(true),
623            limit: Some(3),
624            ..Default::default()
625        };
626
627        let non_existent = PostgresClient::get_paginated_drift_alerts(&pool, &request, &999999)
628            .await
629            .unwrap();
630
631        assert_eq!(non_existent.items.len(), 0);
632        assert!(!non_existent.has_next);
633        assert!(!non_existent.has_previous);
634        assert!(non_existent.next_cursor.is_none());
635        assert!(non_existent.previous_cursor.is_none());
636    }
637
638    #[tokio::test]
639    async fn test_postgres_spc_drift_record() {
640        let pool = db_pool().await;
641
642        let record1 = SpcRecord {
643            created_at: Utc::now(),
644            uid: UID.to_string(),
645            feature: "test".to_string(),
646            value: 1.0,
647            entity_id: 0,
648        };
649
650        let record2 = SpcRecord {
651            created_at: Utc::now(),
652            uid: UID.to_string(),
653            feature: "test2".to_string(),
654            value: 2.0,
655            entity_id: 0,
656        };
657
658        let result =
659            PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2], &ENTITY_ID)
660                .await
661                .unwrap();
662
663        assert_eq!(result.rows_affected(), 2);
664    }
665
666    #[tokio::test]
667    async fn test_postgres_bin_count() {
668        let pool = db_pool().await;
669
670        let record1 = PsiRecord {
671            created_at: Utc::now(),
672            uid: UID.to_string(),
673            feature: "test".to_string(),
674            bin_id: 1,
675            bin_count: 1,
676            entity_id: ENTITY_ID,
677        };
678
679        let record2 = PsiRecord {
680            created_at: Utc::now(),
681            uid: UID.to_string(),
682            feature: "test2".to_string(),
683            bin_id: 2,
684            bin_count: 2,
685            entity_id: ENTITY_ID,
686        };
687
688        let result =
689            PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2], &ENTITY_ID)
690                .await
691                .unwrap();
692
693        assert_eq!(result.rows_affected(), 2);
694    }
695
696    #[tokio::test]
697    async fn test_postgres_observability_record() {
698        let pool = db_pool().await;
699
700        let record = ObservabilityMetrics::default();
701
702        let result = PostgresClient::insert_observability_record(&pool, &record, &ENTITY_ID)
703            .await
704            .unwrap();
705
706        assert_eq!(result.rows_affected(), 1);
707    }
708
709    #[tokio::test]
710    async fn test_postgres_crud_drift_profile() {
711        let pool = db_pool().await;
712
713        let mut spc_profile = SpcDriftProfile::default();
714        let profile = DriftProfile::Spc(spc_profile.clone());
715        let uid = insert_profile_to_db(&pool, &profile, false, false).await;
716        assert!(!uid.is_empty());
717
718        let entity_id = PostgresClient::get_entity_id_from_uid(&pool, &uid)
719            .await
720            .unwrap();
721
722        spc_profile.scouter_version = "test".to_string();
723
724        let result = PostgresClient::update_drift_profile(
725            &pool,
726            &DriftProfile::Spc(spc_profile.clone()),
727            &entity_id,
728        )
729        .await
730        .unwrap();
731
732        assert_eq!(result.rows_affected(), 1);
733
734        let profile = PostgresClient::get_drift_profile(&pool, &entity_id)
735            .await
736            .unwrap();
737
738        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();
739
740        assert_eq!(deserialized, spc_profile);
741
742        PostgresClient::update_drift_profile_status(
743            &pool,
744            &ProfileStatusRequest {
745                name: spc_profile.config.name.clone(),
746                space: spc_profile.config.space.clone(),
747                version: spc_profile.config.version.clone(),
748                active: false,
749                drift_type: Some(DriftType::Spc),
750                deactivate_others: false,
751            },
752        )
753        .await
754        .unwrap();
755    }
756
757    #[tokio::test]
758    async fn test_postgres_get_features() {
759        let pool = db_pool().await;
760
761        let timestamp = Utc::now();
762
763        for _ in 0..10 {
764            let mut records = Vec::new();
765            for j in 0..10 {
766                let record = SpcRecord {
767                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
768                    uid: UID.to_string(),
769                    feature: format!("test{j}"),
770                    value: j as f64,
771                    entity_id: ENTITY_ID,
772                };
773
774                records.push(record);
775            }
776
777            let result =
778                PostgresClient::insert_spc_drift_records_batch(&pool, &records, &ENTITY_ID)
779                    .await
780                    .unwrap();
781            assert_eq!(result.rows_affected(), records.len() as u64);
782        }
783
784        let features = PostgresClient::get_spc_features(&pool, &ENTITY_ID)
785            .await
786            .unwrap();
787        assert_eq!(features.len(), 10);
788
789        let records =
790            PostgresClient::get_spc_drift_records(&pool, &timestamp, &features, &ENTITY_ID)
791                .await
792                .unwrap();
793
794        assert_eq!(records.features.len(), 10);
795
796        let binned_records = PostgresClient::get_binned_spc_drift_records(
797            &pool,
798            &DriftRequest {
799                uid: UID.to_string(),
800                time_interval: TimeInterval::FifteenMinutes,
801                max_data_points: 10,
802                ..Default::default()
803            },
804            &DatabaseSettings::default().retention_period,
805            &ObjectStorageSettings::default(),
806            &ENTITY_ID,
807        )
808        .await
809        .unwrap();
810
811        assert_eq!(binned_records.features.len(), 10);
812    }
813
814    #[tokio::test]
815    async fn test_postgres_bin_proportions() {
816        let pool = db_pool().await;
817
818        let timestamp = Utc::now();
819
820        let num_features = 3;
821        let num_bins = 5;
822
823        let features = (0..=num_features)
824            .map(|feature| {
825                let bins = (0..=num_bins)
826                    .map(|bind_id| Bin {
827                        id: bind_id,
828                        lower_limit: None,
829                        upper_limit: None,
830                        proportion: 0.0,
831                    })
832                    .collect();
833                let feature_name = format!("feature{feature}");
834                let feature_profile = PsiFeatureDriftProfile {
835                    id: feature_name.clone(),
836                    bins,
837                    timestamp,
838                    bin_type: BinType::Numeric,
839                };
840                (feature_name, feature_profile)
841            })
842            .collect();
843
844        let profile = &DriftProfile::Psi(psi::PsiDriftProfile::new(
845            features,
846            PsiDriftConfig {
847                space: SPACE.to_string(),
848                name: NAME.to_string(),
849                version: VERSION.to_string(),
850                ..Default::default()
851            },
852        ));
853        let uid = insert_profile_to_db(&pool, profile, false, false).await;
854        let entity_id = PostgresClient::get_entity_id_from_uid(&pool, &uid)
855            .await
856            .unwrap();
857
858        for feature in 0..num_features {
859            for bin in 0..=num_bins {
860                let mut records = Vec::new();
861                for j in 0..=100 {
862                    let record = PsiRecord {
863                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
864                        uid: uid.to_string(),
865                        feature: format!("feature{feature}"),
866                        bin_id: bin,
867                        bin_count: rand::rng().random_range(0..10) as i32,
868                        entity_id: ENTITY_ID,
869                    };
870
871                    records.push(record);
872                }
873                PostgresClient::insert_bin_counts_batch(&pool, &records, &entity_id)
874                    .await
875                    .unwrap();
876            }
877        }
878
879        let binned_records = PostgresClient::get_feature_distributions(
880            &pool,
881            &timestamp,
882            &["feature0".to_string()],
883            &entity_id,
884        )
885        .await
886        .unwrap();
887
888        // assert binned_records.features["test"]["decile_1"] is around .5
889        let bin_proportion = binned_records
890            .distributions
891            .get("feature0")
892            .unwrap()
893            .bins
894            .get(&1)
895            .unwrap();
896
897        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);
898
899        let binned_records = PostgresClient::get_binned_psi_drift_records(
900            &pool,
901            &DriftRequest {
902                uid: UID.to_string(),
903                time_interval: TimeInterval::OneHour,
904                max_data_points: 1000,
905                ..Default::default()
906            },
907            &DatabaseSettings::default().retention_period,
908            &ObjectStorageSettings::default(),
909            &entity_id,
910        )
911        .await
912        .unwrap();
913        //
914        assert_eq!(binned_records.len(), 3);
915    }
916
917    #[tokio::test]
918    async fn test_postgres_cru_custom_metric() {
919        let pool = db_pool().await;
920        let timestamp = Utc::now();
921
922        let (uid, entity_id) = PostgresClient::create_entity(
923            &pool,
924            SPACE,
925            NAME,
926            VERSION,
927            DriftType::Custom.to_string(),
928        )
929        .await
930        .unwrap();
931
932        for i in 0..2 {
933            let mut records = Vec::new();
934            for j in 0..25 {
935                let record = CustomMetricRecord {
936                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
937                    uid: uid.clone(),
938                    metric: format!("metric{i}"),
939                    value: rand::rng().random_range(0..10) as f64,
940                    entity_id: ENTITY_ID,
941                };
942                records.push(record);
943            }
944            let result =
945                PostgresClient::insert_custom_metric_values_batch(&pool, &records, &entity_id)
946                    .await
947                    .unwrap();
948            assert_eq!(result.rows_affected(), 25);
949        }
950
951        // insert random record to test has statistics funcs handle single record
952        let record = CustomMetricRecord {
953            created_at: Utc::now(),
954            uid: uid.clone(),
955            metric: "metric3".to_string(),
956            value: rand::rng().random_range(0..10) as f64,
957            entity_id: ENTITY_ID,
958        };
959
960        let result =
961            PostgresClient::insert_custom_metric_values_batch(&pool, &[record], &entity_id)
962                .await
963                .unwrap();
964        assert_eq!(result.rows_affected(), 1);
965
966        let metrics = PostgresClient::get_custom_metric_values(
967            &pool,
968            &timestamp,
969            &["metric1".to_string()],
970            &entity_id,
971        )
972        .await
973        .unwrap();
974
975        assert_eq!(metrics.len(), 1);
976
977        let binned_records = PostgresClient::get_binned_custom_drift_records(
978            &pool,
979            &DriftRequest {
980                uid: uid.clone(),
981                time_interval: TimeInterval::OneHour,
982                max_data_points: 1000,
983                ..Default::default()
984            },
985            &DatabaseSettings::default().retention_period,
986            &ObjectStorageSettings::default(),
987            &entity_id,
988        )
989        .await
990        .unwrap();
991        //
992        assert_eq!(binned_records.metrics.len(), 3);
993    }
994
995    #[tokio::test]
996    async fn test_postgres_user() {
997        let pool = db_pool().await;
998        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];
999
1000        // Create
1001        let user = User::new(
1002            "user".to_string(),
1003            "pass".to_string(),
1004            "email".to_string(),
1005            recovery_codes,
1006            None,
1007            None,
1008            None,
1009            None,
1010        );
1011        PostgresClient::insert_user(&pool, &user).await.unwrap();
1012
1013        // Read
1014        let mut user = PostgresClient::get_user(&pool, "user")
1015            .await
1016            .unwrap()
1017            .unwrap();
1018
1019        assert_eq!(user.username, "user");
1020        assert_eq!(user.group_permissions, vec!["user"]);
1021        assert_eq!(user.email, "email");
1022
1023        // update user
1024        user.active = false;
1025        user.refresh_token = Some("token".to_string());
1026
1027        // Update
1028        PostgresClient::update_user(&pool, &user).await.unwrap();
1029        let user = PostgresClient::get_user(&pool, "user")
1030            .await
1031            .unwrap()
1032            .unwrap();
1033        assert!(!user.active);
1034        assert_eq!(user.refresh_token.unwrap(), "token");
1035
1036        // get users
1037        let users = PostgresClient::get_users(&pool).await.unwrap();
1038        assert_eq!(users.len(), 1);
1039
1040        // get last admin
1041        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
1042        assert!(!is_last_admin);
1043
1044        // delete
1045        PostgresClient::delete_user(&pool, "user").await.unwrap();
1046    }
1047
1048    #[tokio::test]
1049    async fn test_postgres_genai_eval_record_insert_get() {
1050        let pool = db_pool().await;
1051
1052        let (uid, entity_id) = PostgresClient::create_entity(
1053            &pool,
1054            SPACE,
1055            NAME,
1056            VERSION,
1057            DriftType::GenAI.to_string(),
1058        )
1059        .await
1060        .unwrap();
1061
1062        let input = "This is a test input";
1063        let output = "This is a test response";
1064
1065        for j in 0..10 {
1066            let context = serde_json::json!({
1067                "input": input,
1068                "response": output,
1069            });
1070            let record = GenAIEvalRecord {
1071                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1072                context,
1073                status: Status::Pending,
1074                id: 0, // This will be set by the database
1075                uid: format!("test_{}", j),
1076                entity_uid: uid.clone(),
1077                entity_id,
1078                ..Default::default()
1079            };
1080
1081            let boxed = BoxedGenAIEvalRecord::new(record);
1082
1083            let result = PostgresClient::insert_genai_eval_record(&pool, boxed, &entity_id)
1084                .await
1085                .unwrap();
1086
1087            assert_eq!(result.rows_affected(), 1);
1088        }
1089
1090        let features = PostgresClient::get_genai_eval_records(&pool, None, None, &entity_id)
1091            .await
1092            .unwrap();
1093        assert_eq!(features.len(), 10);
1094
1095        // get pending task
1096        let pending_tasks = PostgresClient::get_pending_genai_eval_record(&pool)
1097            .await
1098            .unwrap();
1099
1100        // assert not empty
1101        assert!(pending_tasks.is_some());
1102
1103        // get pending task with space, name, version
1104        let task_input = &pending_tasks.as_ref().unwrap().context["input"];
1105        assert_eq!(*task_input, "This is a test input".to_string());
1106
1107        // update pending task
1108        PostgresClient::update_genai_eval_record_status(
1109            &pool,
1110            &pending_tasks.unwrap(),
1111            Status::Processed,
1112            &(1_i64),
1113        )
1114        .await
1115        .unwrap();
1116
1117        // query processed tasks
1118        let processed_tasks = PostgresClient::get_genai_eval_records(
1119            &pool,
1120            None,
1121            Some(Status::Processed),
1122            &entity_id,
1123        )
1124        .await
1125        .unwrap();
1126
1127        // assert not empty
1128        assert_eq!(processed_tasks.len(), 1);
1129    }
1130
1131    #[tokio::test]
1132    async fn test_postgres_genai_eval_record_pagination() {
1133        let pool = db_pool().await;
1134
1135        let (uid, entity_id) = PostgresClient::create_entity(
1136            &pool,
1137            SPACE,
1138            NAME,
1139            VERSION,
1140            DriftType::GenAI.to_string(),
1141        )
1142        .await
1143        .unwrap();
1144
1145        let input = "This is a test input";
1146        let output = "This is a test response";
1147
1148        // Insert 10 records with increasing timestamps
1149        for j in 0..10 {
1150            let context = serde_json::json!({
1151                "input": input,
1152                "response": output,
1153            });
1154            let record = GenAIEvalRecord {
1155                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1156                context,
1157                status: Status::Pending,
1158                id: 0, // This will be set by the database
1159                uid: format!("test_{}", j),
1160                entity_uid: uid.clone(),
1161                entity_id: ENTITY_ID,
1162                ..Default::default()
1163            };
1164
1165            let boxed = BoxedGenAIEvalRecord::new(record);
1166
1167            let result = PostgresClient::insert_genai_eval_record(&pool, boxed, &entity_id)
1168                .await
1169                .unwrap();
1170
1171            assert_eq!(result.rows_affected(), 1);
1172        }
1173
1174        // ===== PAGE 1: Get first 5 records (newest) =====
1175        let params = GenAIEvalRecordPaginationRequest {
1176            status: None,
1177            limit: Some(5),
1178            cursor_created_at: None,
1179            cursor_id: None,
1180            direction: None,
1181            ..Default::default()
1182        };
1183
1184        let page1 = PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1185            .await
1186            .unwrap();
1187
1188        assert_eq!(page1.items.len(), 5, "Page 1 should have 5 records");
1189        assert!(page1.has_next, "Should have next page");
1190        assert!(
1191            !page1.has_previous,
1192            "Should not have previous page (first page)"
1193        );
1194        assert!(page1.next_cursor.is_some(), "Should have next cursor");
1195
1196        // First item should be the NEWEST record (highest ID)
1197        let page1_first = page1.items.first().unwrap();
1198        let page1_last = page1.items.last().unwrap();
1199
1200        assert!(
1201            page1_first.created_at >= page1_last.created_at,
1202            "Page 1 should be sorted newest first (DESC)"
1203        );
1204
1205        // ===== PAGE 2: Get next 5 records (older) =====
1206        let next_cursor = page1.next_cursor.unwrap();
1207
1208        let params = GenAIEvalRecordPaginationRequest {
1209            status: None,
1210            limit: Some(5),
1211            cursor_created_at: Some(next_cursor.created_at),
1212            cursor_id: Some(next_cursor.id),
1213            direction: None,
1214            ..Default::default()
1215        };
1216
1217        let page2 = PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1218            .await
1219            .unwrap();
1220
1221        assert_eq!(page2.items.len(), 5, "Page 2 should have 5 records");
1222        assert!(!page2.has_next, "Should not have next page (last page)");
1223        assert!(page2.has_previous, "Should have previous page");
1224        assert!(
1225            page2.previous_cursor.is_some(),
1226            "Should have previous cursor"
1227        );
1228
1229        let page2_first = page2.items.first().unwrap();
1230
1231        // Page 2 first item should be OLDER than Page 1 last item
1232        assert!(
1233            page2_first.created_at < page1_last.created_at
1234                || (page2_first.created_at == page1_last.created_at
1235                    && page2_first.id < page1_last.id),
1236            "Page 2 should start with records older than Page 1 last item"
1237        );
1238
1239        // Verify we got all 10 records across both pages
1240        let all_ids: Vec<i64> = page1
1241            .items
1242            .iter()
1243            .chain(page2.items.iter())
1244            .map(|r| r.id)
1245            .collect();
1246
1247        assert_eq!(all_ids.len(), 10, "Should have 10 unique records total");
1248
1249        // All IDs should be unique
1250        let unique_ids: std::collections::HashSet<_> = all_ids.iter().collect();
1251        assert_eq!(unique_ids.len(), 10, "All IDs should be unique");
1252
1253        // ===== BACKWARD PAGINATION TEST =====
1254        // Go back from page 2 to page 1
1255        let previous_cursor = page2.previous_cursor.unwrap();
1256
1257        let params = GenAIEvalRecordPaginationRequest {
1258            status: None,
1259            limit: Some(5),
1260            cursor_created_at: Some(previous_cursor.created_at),
1261            cursor_id: Some(previous_cursor.id),
1262            direction: Some("previous".to_string()),
1263            ..Default::default()
1264        };
1265
1266        let page1_again =
1267            PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1268                .await
1269                .unwrap();
1270
1271        assert_eq!(
1272            page1_again.items.len(),
1273            5,
1274            "Going back should return 5 records"
1275        );
1276
1277        // Should get the same records as page 1
1278        assert_eq!(
1279            page1_again.items.first().unwrap().id,
1280            page1_first.id,
1281            "Should return to the same first record"
1282        );
1283    }
1284
1285    #[tokio::test]
1286    async fn test_postgres_genai_eval_workflow_pagination() {
1287        let pool = db_pool().await;
1288
1289        let (_uid, entity_id) = PostgresClient::create_entity(
1290            &pool,
1291            SPACE,
1292            NAME,
1293            VERSION,
1294            DriftType::GenAI.to_string(),
1295        )
1296        .await
1297        .unwrap();
1298
1299        // Insert 10 records with increasing timestamps
1300        for j in 0..10 {
1301            let record = GenAIEvalWorkflowResult {
1302                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1303                record_uid: format!("test_{}", j),
1304                entity_id,
1305                ..Default::default()
1306            };
1307
1308            let result =
1309                PostgresClient::insert_genai_eval_workflow_record(&pool, &record, &entity_id)
1310                    .await
1311                    .unwrap();
1312
1313            assert_eq!(result.rows_affected(), 1);
1314        }
1315
1316        // ===== PAGE 1: Get first 5 records (newest) =====
1317        let params = GenAIEvalRecordPaginationRequest {
1318            status: None,
1319            limit: Some(5),
1320            cursor_created_at: None,
1321            cursor_id: None,
1322            direction: None,
1323            ..Default::default()
1324        };
1325
1326        let page1 =
1327            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1328                .await
1329                .unwrap();
1330
1331        assert_eq!(page1.items.len(), 5, "Page 1 should have 5 records");
1332        assert!(page1.has_next, "Should have next page");
1333        assert!(
1334            !page1.has_previous,
1335            "Should not have previous page (first page)"
1336        );
1337        assert!(page1.next_cursor.is_some(), "Should have next cursor");
1338
1339        // First item should be the NEWEST record (highest ID)
1340        let page1_first = page1.items.first().unwrap();
1341        let page1_last = page1.items.last().unwrap();
1342
1343        assert!(
1344            page1_first.created_at >= page1_last.created_at,
1345            "Page 1 should be sorted newest first (DESC)"
1346        );
1347
1348        // ===== PAGE 2: Get next 5 records (older) =====
1349        let next_cursor = page1.next_cursor.unwrap();
1350
1351        let params = GenAIEvalRecordPaginationRequest {
1352            status: None,
1353            limit: Some(5),
1354            cursor_created_at: Some(next_cursor.created_at),
1355            cursor_id: Some(next_cursor.id),
1356            direction: None,
1357            ..Default::default()
1358        };
1359
1360        let page2 =
1361            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1362                .await
1363                .unwrap();
1364
1365        assert_eq!(page2.items.len(), 5, "Page 2 should have 5 records");
1366        assert!(!page2.has_next, "Should not have next page (last page)");
1367        assert!(page2.has_previous, "Should have previous page");
1368        assert!(
1369            page2.previous_cursor.is_some(),
1370            "Should have previous cursor"
1371        );
1372
1373        let page2_first = page2.items.first().unwrap();
1374
1375        // Page 2 first item should be OLDER than Page 1 last item
1376        assert!(
1377            page2_first.created_at < page1_last.created_at
1378                || (page2_first.created_at == page1_last.created_at
1379                    && page2_first.id < page1_last.id),
1380            "Page 2 should start with records older than Page 1 last item"
1381        );
1382
1383        // Verify we got all 10 records across both pages
1384        let all_ids: Vec<i64> = page1
1385            .items
1386            .iter()
1387            .chain(page2.items.iter())
1388            .map(|r| r.id)
1389            .collect();
1390
1391        assert_eq!(all_ids.len(), 10, "Should have 10 unique records total");
1392
1393        // All IDs should be unique
1394        let unique_ids: std::collections::HashSet<_> = all_ids.iter().collect();
1395        assert_eq!(unique_ids.len(), 10, "All IDs should be unique");
1396
1397        // ===== BACKWARD PAGINATION TEST =====
1398        // Go back from page 2 to page 1
1399        let previous_cursor = page2.previous_cursor.unwrap();
1400
1401        let params = GenAIEvalRecordPaginationRequest {
1402            status: None,
1403            limit: Some(5),
1404            cursor_created_at: Some(previous_cursor.created_at),
1405            cursor_id: Some(previous_cursor.id),
1406            direction: Some("previous".to_string()),
1407            ..Default::default()
1408        };
1409
1410        let page1_again =
1411            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1412                .await
1413                .unwrap();
1414
1415        assert_eq!(
1416            page1_again.items.len(),
1417            5,
1418            "Going back should return 5 records"
1419        );
1420
1421        // Should get the same records as page 1
1422        assert_eq!(
1423            page1_again.items.first().unwrap().id,
1424            page1_first.id,
1425            "Should return to the same first record"
1426        );
1427    }
1428
1429    #[tokio::test]
1430    async fn test_postgres_genai_task_result_insert_get() {
1431        let pool = db_pool().await;
1432
1433        let timestamp = Utc::now();
1434
1435        let (uid, entity_id) = PostgresClient::create_entity(
1436            &pool,
1437            SPACE,
1438            NAME,
1439            VERSION,
1440            DriftType::GenAI.to_string(),
1441        )
1442        .await
1443        .unwrap();
1444
1445        let mut records = Vec::new();
1446        for i in 0..2 {
1447            for j in 0..25 {
1448                let record = GenAIEvalTaskResult {
1449                    record_uid: format!("record_uid_{i}_{j}"),
1450                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1451                    start_time: Utc::now(),
1452                    end_time: Utc::now() + chrono::Duration::seconds(1),
1453                    entity_id,
1454                    task_id: format!("task{i}"),
1455                    task_type: scouter_types::genai::EvaluationTaskType::Assertion,
1456                    passed: true,
1457                    value: j as f64,
1458                    field_path: Some(format!("field.path.{i}")),
1459                    operator: scouter_types::genai::ComparisonOperator::Contains,
1460                    expected: Value::Null,
1461                    actual: Value::Null,
1462                    message: "All good".to_string(),
1463                    entity_uid: uid.clone(),
1464                    condition: false,
1465                    stage: 0_i32,
1466                };
1467                records.push(record);
1468            }
1469            let result =
1470                PostgresClient::insert_eval_task_results_batch(&pool, &records, &entity_id)
1471                    .await
1472                    .unwrap();
1473            assert_eq!(result.rows_affected(), 25);
1474        }
1475
1476        let metrics = PostgresClient::get_genai_task_values(
1477            &pool,
1478            &timestamp,
1479            &["task1".to_string()],
1480            &entity_id,
1481        )
1482        .await
1483        .unwrap();
1484
1485        assert_eq!(metrics.len(), 1);
1486        let binned_records = PostgresClient::get_binned_genai_task_values(
1487            &pool,
1488            &DriftRequest {
1489                uid: uid.clone(),
1490                time_interval: TimeInterval::OneHour,
1491                max_data_points: 1000,
1492                ..Default::default()
1493            },
1494            &DatabaseSettings::default().retention_period,
1495            &ObjectStorageSettings::default(),
1496            &entity_id,
1497        )
1498        .await
1499        .unwrap();
1500        //
1501        assert_eq!(binned_records.metrics.len(), 2);
1502
1503        let eval_task = PostgresClient::get_genai_eval_task(&pool, &records[0].record_uid)
1504            .await
1505            .unwrap();
1506
1507        assert_eq!(eval_task[0].record_uid, records[0].record_uid);
1508    }
1509
1510    #[tokio::test]
1511    async fn test_postgres_genai_workflow_result_insert_get() {
1512        let pool = db_pool().await;
1513
1514        let timestamp = Utc::now();
1515
1516        let (uid, entity_id) = PostgresClient::create_entity(
1517            &pool,
1518            SPACE,
1519            NAME,
1520            VERSION,
1521            DriftType::GenAI.to_string(),
1522        )
1523        .await
1524        .unwrap();
1525
1526        for i in 0..2 {
1527            for j in 0..25 {
1528                let record = GenAIEvalWorkflowResult {
1529                    record_uid: format!("record_uid_{i}_{j}"),
1530                    created_at: Utc::now() + chrono::Duration::hours(i),
1531                    entity_id,
1532                    total_tasks: 10,
1533                    passed_tasks: 8,
1534                    failed_tasks: 2,
1535                    pass_rate: 0.8,
1536                    duration_ms: 1500,
1537                    entity_uid: uid.clone(),
1538                    execution_plan: ExecutionPlan::default(),
1539                    id: 0,
1540                };
1541                let result =
1542                    PostgresClient::insert_genai_eval_workflow_record(&pool, &record, &entity_id)
1543                        .await
1544                        .unwrap();
1545                assert_eq!(result.rows_affected(), 1);
1546            }
1547        }
1548
1549        let metric = PostgresClient::get_genai_workflow_value(&pool, &timestamp, &entity_id)
1550            .await
1551            .unwrap();
1552
1553        assert!(metric.is_some());
1554
1555        let binned_records = PostgresClient::get_binned_genai_workflow_values(
1556            &pool,
1557            &DriftRequest {
1558                uid: uid.clone(),
1559                time_interval: TimeInterval::OneHour,
1560                max_data_points: 1000,
1561                ..Default::default()
1562            },
1563            &DatabaseSettings::default().retention_period,
1564            &ObjectStorageSettings::default(),
1565            &entity_id,
1566        )
1567        .await
1568        .unwrap();
1569        //
1570        assert_eq!(binned_records.metrics.len(), 1);
1571    }
1572
1573    #[tokio::test]
1574    async fn test_postgres_tracing_metrics() {
1575        let pool = db_pool().await;
1576        let script = std::fs::read_to_string("src/tests/script/populate_trace.sql").unwrap();
1577        sqlx::query(&script).execute(&pool).await.unwrap();
1578        let mut filters = TraceFilters::default();
1579
1580        let first_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1581            .await
1582            .unwrap();
1583
1584        assert_eq!(
1585            first_batch.items.len(),
1586            50,
1587            "First batch should have 50 records"
1588        );
1589
1590        // test pagination (get last record created_at and trace_id)
1591        let last_record = first_batch.next_cursor.unwrap();
1592        filters = filters.next_page(&last_record);
1593
1594        let next_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1595            .await
1596            .unwrap();
1597
1598        // should be another 50 records
1599        assert_eq!(
1600            next_batch.items.len(),
1601            50,
1602            "Next batch should have 50 records"
1603        );
1604
1605        // assert next_batch first record timestamp is <= first_batch last record timestamp
1606        let next_first_record = next_batch.items.first().unwrap();
1607        assert!(
1608            next_first_record.start_time <= last_record.start_time,
1609            "Next batch first record timestamp is not less than or equal to last record timestamp"
1610        );
1611
1612        // test pagination for previous
1613        filters = filters.previous_page(&next_batch.previous_cursor.unwrap());
1614        let previous_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1615            .await
1616            .unwrap();
1617        assert_eq!(
1618            previous_batch.items.len(),
1619            50,
1620            "Previous batch should have 50 records"
1621        );
1622
1623        // Filter records to find item with >5 spans
1624        let filtered_record = first_batch
1625            .items
1626            .iter()
1627            .find(|record| record.span_count > 5)
1628            .unwrap();
1629
1630        filters.cursor_start_time = None;
1631        filters.cursor_trace_id = None;
1632
1633        let records = PostgresClient::get_paginated_traces(&pool, filters.clone())
1634            .await
1635            .unwrap();
1636
1637        // Records are randomly generated, so just assert we get some records back
1638        assert!(
1639            !records.items.is_empty(),
1640            "Should return records with specified filters"
1641        );
1642
1643        // get spans for filtered trace
1644        let spans = PostgresClient::get_trace_spans(
1645            &pool,
1646            &filtered_record.trace_id,
1647            Some(&filtered_record.service_name),
1648        )
1649        .await
1650        .unwrap();
1651
1652        assert!(spans.len() == filtered_record.span_count as usize);
1653
1654        let start_time = filtered_record.start_time - chrono::Duration::hours(48);
1655        let end_time = filtered_record.start_time + chrono::Duration::minutes(5);
1656
1657        // make request for trace metrics
1658        let trace_metrics =
1659            PostgresClient::get_trace_metrics(&pool, None, start_time, end_time, "5 minutes", None)
1660                .await
1661                .unwrap();
1662
1663        // assert we have data points
1664        assert!(trace_metrics.len() >= 10);
1665
1666        // get paginated traces with tags
1667        let filters = scouter_types::sql::TraceFilters {
1668            attribute_filters: Some(vec![("component=kafka".to_string())]),
1669            ..Default::default()
1670        };
1671        let tagged_batch = PostgresClient::get_paginated_traces(&pool, filters)
1672            .await
1673            .unwrap();
1674        assert!(!tagged_batch.items.is_empty());
1675    }
1676
1677    #[tokio::test]
1678    async fn test_postgres_tracing_insert() {
1679        let pool = db_pool().await;
1680
1681        // create parent trace
1682        let mut trace_record = random_trace_record();
1683        let trace_id = trace_record.trace_id.clone();
1684
1685        // create spans
1686        let root_span = random_span_record(&trace_id, None, &trace_record.service_name);
1687        let child_span =
1688            random_span_record(&trace_id, Some(&root_span.span_id), &root_span.service_name);
1689
1690        // set root span id in trace record
1691        trace_record.root_span_id = root_span.span_id.clone();
1692
1693        // insert spans
1694        let result =
1695            PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
1696                .await
1697                .unwrap();
1698
1699        assert_eq!(result.rows_affected(), 2);
1700
1701        let inserted_created_at = trace_record.created_at;
1702        let inserted_trace_id = trace_record.trace_id.clone();
1703
1704        let trace_filter = TraceFilters {
1705            cursor_start_time: Some(inserted_created_at + Duration::days(1)),
1706            cursor_trace_id: Some(inserted_trace_id),
1707            start_time: Some(inserted_created_at - Duration::minutes(5)),
1708            end_time: Some(inserted_created_at + Duration::days(1)),
1709            ..TraceFilters::default()
1710        };
1711
1712        let traces = PostgresClient::get_paginated_traces(&pool, trace_filter)
1713            .await
1714            .unwrap();
1715
1716        assert_eq!(traces.items.len(), 1);
1717        let retrieved_trace = &traces.items[0];
1718        // assert span count is 2
1719        assert_eq!(retrieved_trace.span_count, 2);
1720
1721        let baggage = TraceBaggageRecord {
1722            created_at: Utc::now(),
1723            trace_id: trace_record.trace_id.clone(),
1724            scope: "test_scope".to_string(),
1725            key: "user_id".to_string(),
1726            value: "12345".to_string(),
1727        };
1728
1729        let result =
1730            PostgresClient::insert_trace_baggage_batch(&pool, std::slice::from_ref(&baggage))
1731                .await
1732                .unwrap();
1733
1734        assert_eq!(result.rows_affected(), 1);
1735
1736        let retrieved_baggage =
1737            PostgresClient::get_trace_baggage_records(&pool, &trace_record.trace_id)
1738                .await
1739                .unwrap();
1740
1741        assert_eq!(retrieved_baggage.len(), 1);
1742    }
1743
1744    #[tokio::test]
1745    async fn test_postgres_tags() {
1746        let pool = db_pool().await;
1747        let uid = create_uuid7();
1748
1749        let tag1 = TagRecord {
1750            entity_id: uid.clone(),
1751            entity_type: "service".to_string(),
1752            key: "env".to_string(),
1753            value: "production".to_string(),
1754        };
1755
1756        let tag2 = TagRecord {
1757            entity_id: uid.clone(),
1758            entity_type: "service".to_string(),
1759            key: "team".to_string(),
1760            value: "backend".to_string(),
1761        };
1762
1763        let result = PostgresClient::insert_tag_batch(&pool, &[tag1.clone(), tag2.clone()])
1764            .await
1765            .unwrap();
1766
1767        assert_eq!(result.rows_affected(), 2);
1768
1769        let tags = PostgresClient::get_tags(&pool, "service", &uid)
1770            .await
1771            .unwrap();
1772
1773        assert_eq!(tags.len(), 2);
1774
1775        let tag_filter = vec![Tag {
1776            key: tags.first().unwrap().key.clone(),
1777            value: tags.first().unwrap().value.clone(),
1778        }];
1779
1780        let entity_id = PostgresClient::get_entity_id_by_tags(&pool, "service", &tag_filter, false)
1781            .await
1782            .unwrap();
1783
1784        assert_eq!(entity_id.first().unwrap(), &uid);
1785    }
1786
1787    #[tokio::test]
1788    async fn test_postgres_get_spans_from_tags() {
1789        let pool = db_pool().await;
1790
1791        // create parent trace
1792        let mut trace_record = random_trace_record();
1793        let trace_id = trace_record.trace_id.clone();
1794
1795        // create spans
1796        let root_span = random_span_record(&trace_id, None, &trace_record.service_name);
1797        let child_span =
1798            random_span_record(&trace_id, Some(&root_span.span_id), &root_span.service_name);
1799
1800        // set root span id in trace record
1801        trace_record.root_span_id = root_span.span_id.clone();
1802
1803        // insert spans
1804        let result =
1805            PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
1806                .await
1807                .unwrap();
1808
1809        assert_eq!(result.rows_affected(), 2);
1810
1811        let tag = TagRecord {
1812            entity_id: trace_record.trace_id.clone(),
1813            entity_type: "trace".to_string(),
1814            key: "env".to_string(),
1815            value: "production".to_string(),
1816        };
1817
1818        let result = PostgresClient::insert_tag_batch(&pool, std::slice::from_ref(&tag))
1819            .await
1820            .unwrap();
1821
1822        assert_eq!(result.rows_affected(), 1);
1823
1824        let tag_filters = vec![HashMap::from([
1825            ("key".to_string(), "env".to_string()),
1826            ("value".to_string(), "production".to_string()),
1827        ])];
1828
1829        let spans = PostgresClient::get_spans_from_tags(
1830            &pool,
1831            "trace",
1832            tag_filters,
1833            true,
1834            Some(&trace_record.service_name),
1835        )
1836        .await
1837        .unwrap();
1838
1839        assert_eq!(spans.len(), 2);
1840    }
1841}