Skip to main content

scouter_sql/sql/
postgres.rs

1use crate::sql::aggregator::init_trace_cache;
2use crate::sql::cache::entity_cache;
3use crate::sql::cache::init_entity_cache;
4use crate::sql::error::SqlError;
5use crate::sql::traits::{
6    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, GenAIDriftSqlLogic,
7    ObservabilitySqlLogic, ProfileSqlLogic, PsiSqlLogic, SpcSqlLogic, TagSqlLogic, TraceSqlLogic,
8    UserSqlLogic,
9};
10use scouter_settings::DatabaseSettings;
11use scouter_types::{RecordType, ServerRecords, TagRecord, ToDriftRecords, TraceServerRecord};
12use sqlx::ConnectOptions;
13use sqlx::{postgres::PgConnectOptions, Pool, Postgres};
14use std::result::Result::Ok;
15use std::time::Duration;
16use tokio::try_join;
17use tracing::log::LevelFilter;
18use tracing::{debug, error, info, instrument};
/// Maximum number of records sent to Postgres in a single batch insert; larger record sets
/// are split into chunks of this size before they are inserted.
const DEFAULT_BATCH_SIZE: usize = 500;
20
21#[derive(Debug, Clone)]
22#[allow(dead_code)]
23pub struct PostgresClient {}
24
25impl SpcSqlLogic for PostgresClient {}
26impl CustomMetricSqlLogic for PostgresClient {}
27impl PsiSqlLogic for PostgresClient {}
28impl GenAIDriftSqlLogic for PostgresClient {}
29impl UserSqlLogic for PostgresClient {}
30impl ProfileSqlLogic for PostgresClient {}
31impl ObservabilitySqlLogic for PostgresClient {}
32impl AlertSqlLogic for PostgresClient {}
33impl ArchiveSqlLogic for PostgresClient {}
34impl TraceSqlLogic for PostgresClient {}
35impl TagSqlLogic for PostgresClient {}
36
37impl PostgresClient {
38    /// Setup the application with the given database pool.
39    ///
40    /// # Returns
41    ///
42    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
43    #[instrument(skip(database_settings))]
44    pub async fn create_db_pool(
45        database_settings: &DatabaseSettings,
46    ) -> Result<Pool<Postgres>, SqlError> {
47        let mut opts: PgConnectOptions = database_settings.connection_uri.parse()?;
48
49        // Sqlx logs a lot of debug information by default, which can be overwhelming.
50        opts = opts.log_statements(LevelFilter::Off);
51
52        let pool = match sqlx::postgres::PgPoolOptions::new()
53            .max_connections(database_settings.max_connections)
54            .min_connections(database_settings.min_connections)
55            .acquire_timeout(Duration::from_secs(
56                database_settings.db_acquire_timeout_seconds,
57            ))
58            .idle_timeout(Duration::from_secs(
59                database_settings.db_idle_timeout_seconds,
60            ))
61            .max_lifetime(Duration::from_secs(
62                database_settings.db_max_lifetime_seconds,
63            ))
64            .test_before_acquire(database_settings.db_test_before_acquire)
65            .connect_with(opts)
66            .await
67        {
68            Ok(pool) => {
69                info!("✅ Successfully connected to database");
70                pool
71            }
72            Err(err) => {
73                error!("🚨 Failed to connect to database {:?}", err);
74                std::process::exit(1);
75            }
76        };
77
78        // setup entity cache
79        init_entity_cache(database_settings.entity_cache_size);
80
81        // setup trace cache
82        init_trace_cache(
83            pool.clone(),
84            database_settings.flush_interval,
85            database_settings.stale_threshold,
86            database_settings.max_cache_size,
87        )
88        .await?;
89
90        // Run migrations
91        if let Err(err) = Self::run_migrations(&pool).await {
92            error!("🚨 Failed to run migrations {:?}", err);
93            std::process::exit(1);
94        }
95
96        Ok(pool)
97    }
98
99    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
100        info!("Running migrations");
101        sqlx::migrate!("src/migrations")
102            .run(pool)
103            .await
104            .map_err(SqlError::MigrateError)?;
105
106        debug!("Migrations complete");
107
108        Ok(())
109    }
110}
111
112pub struct MessageHandler {}
113
114impl MessageHandler {
115    #[instrument(skip_all)]
116    pub async fn insert_server_records(
117        pool: &Pool<Postgres>,
118        records: ServerRecords,
119    ) -> Result<(), SqlError> {
120        debug!("Inserting server records: {:?}", records.record_type()?);
121
122        let entity_id = entity_cache()
123            .get_entity_id_from_uid(pool, records.uid()?)
124            .await
125            .inspect_err(|e| error!("Failed to get entity ID from UID: {:?}", e))?;
126
127        match records.record_type()? {
128            RecordType::Spc => {
129                let spc_records = records.to_spc_drift_records()?;
130                debug!("SPC record count: {}", spc_records.len());
131
132                for chunk in spc_records.chunks(DEFAULT_BATCH_SIZE) {
133                    PostgresClient::insert_spc_drift_records_batch(pool, chunk, &entity_id)
134                        .await
135                        .map_err(|e| {
136                            error!("Failed to insert SPC drift records batch: {:?}", e);
137                            e
138                        })?;
139                }
140            }
141
142            RecordType::Psi => {
143                let psi_records = records.to_psi_drift_records()?;
144                debug!("PSI record count: {}", psi_records.len());
145
146                for chunk in psi_records.chunks(DEFAULT_BATCH_SIZE) {
147                    PostgresClient::insert_bin_counts_batch(pool, chunk, &entity_id)
148                        .await
149                        .map_err(|e| {
150                            error!("Failed to insert PSI drift records batch: {:?}", e);
151                            e
152                        })?;
153                }
154            }
155            RecordType::Custom => {
156                let custom_records = records.to_custom_metric_drift_records()?;
157                debug!("Custom record count: {}", custom_records.len());
158
159                for chunk in custom_records.chunks(DEFAULT_BATCH_SIZE) {
160                    PostgresClient::insert_custom_metric_values_batch(pool, chunk, &entity_id)
161                        .await
162                        .map_err(|e| {
163                            error!("Failed to insert custom metric records batch: {:?}", e);
164                            e
165                        })?;
166                }
167            }
168
169            RecordType::GenAIEval => {
170                debug!("LLM Drift record count: {:?}", records.len());
171                let records = records.to_genai_eval_records()?;
172                for record in records {
173                    let _ = PostgresClient::insert_genai_eval_record(pool, record, &entity_id)
174                        .await
175                        .map_err(|e| {
176                            error!("Failed to insert GenAI drift record: {:?}", e);
177                        });
178                }
179            }
180
181            RecordType::GenAITask => {
182                debug!("GenAI Task count: {:?}", records.len());
183                let records = records.to_genai_task_records()?;
184                for chunk in records.chunks(DEFAULT_BATCH_SIZE) {
185                    PostgresClient::insert_eval_task_results_batch(pool, chunk, &entity_id)
186                        .await
187                        .map_err(|e| {
188                            error!("Failed to insert GenAI task records batch: {:?}", e);
189                            e
190                        })?;
191                }
192            }
193
194            RecordType::GenAIWorkflow => {
195                debug!("GenAI Workflow count: {:?}", records.len());
196                let records = records.to_genai_workflow_records()?;
197                for record in records {
198                    let _ = PostgresClient::insert_genai_eval_workflow_record(
199                        pool, &record, &entity_id,
200                    )
201                    .await
202                    .map_err(|e| {
203                        error!("Failed to insert GenAI workflow record: {:?}", e);
204                    });
205                }
206            }
207
208            _ => {
209                error!(
210                    "Unsupported record type for batch insert: {:?}",
211                    records.record_type()?
212                );
213                return Err(SqlError::UnsupportedBatchTypeError);
214            }
215        }
216
217        Ok(())
218    }
219
220    pub async fn insert_trace_server_record(
221        pool: &Pool<Postgres>,
222        records: TraceServerRecord,
223    ) -> Result<(), SqlError> {
224        let (span_batch, baggage_batch, tag_records) = records.to_records()?;
225
226        let (span_result, baggage_result, tag_result) = try_join!(
227            PostgresClient::insert_span_batch(pool, &span_batch),
228            PostgresClient::insert_trace_baggage_batch(pool, &baggage_batch),
229            PostgresClient::insert_tag_batch(pool, &tag_records),
230        )?;
231
232        debug!(
233            span_rows = span_result.rows_affected(),
234            baggage_rows = baggage_result.rows_affected(),
235            total_spans = span_batch.len(),
236            total_baggage = baggage_batch.len(),
237            tag_rows = tag_result.rows_affected(),
238            "Successfully inserted trace server records"
239        );
240        Ok(())
241    }
242
243    pub async fn insert_tag_record(
244        pool: &Pool<Postgres>,
245        record: TagRecord,
246    ) -> Result<(), SqlError> {
247        let result = PostgresClient::insert_tag_batch(pool, std::slice::from_ref(&record)).await?;
248
249        debug!(
250            rows_affected = result.rows_affected(),
251            entity_type = record.entity_type.as_str(),
252            entity_id = record.entity_id.as_str(),
253            key = record.key.as_str(),
254            "Successfully inserted tag record"
255        );
256
257        Ok(())
258    }
259}
260
/// Database integration tests for the Postgres client.
///
/// Note: binned queries that target custom intervals with both long-term and short-term data
/// are tested in the scouter-server integration tests instead.
264#[cfg(test)]
265mod tests {
266
267    use std::collections::HashMap;
268
269    use super::*;
270    use crate::sql::aggregator::shutdown_trace_cache;
271    use crate::sql::schema::User;
272    use crate::sql::traits::EntitySqlLogic;
273    use chrono::{Duration, Utc};
274    use potato_head::create_uuid7;
275    use rand::Rng;
276    use scouter_mocks::init_tracing;
277    use scouter_semver::VersionType;
278    use scouter_settings::ObjectStorageSettings;
279    use scouter_types::genai::ExecutionPlan;
280    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
281    use scouter_types::spc::SpcDriftProfile;
282    use scouter_types::sql::TraceFilters;
283    use scouter_types::*;
284    use serde_json::Value;
285
286    const SPACE: &str = "space";
287    const NAME: &str = "name";
288    const VERSION: &str = "1.0.0";
289    const SCOPE: &str = "scope";
290    const ENTITY_ID: i32 = 9999;
291
292    fn random_trace_record() -> TraceRecord {
293        let mut rng = rand::rng();
294        let random_num = rng.random_range(0..1000);
295        let trace_id = TraceId::from_bytes(rng.random::<[u8; 16]>());
296        let span_id = SpanId::from_bytes(rng.random::<[u8; 8]>());
297        let created_at = Utc::now() + chrono::Duration::milliseconds(random_num);
298
299        TraceRecord {
300            trace_id,
301            created_at,
302            service_name: format!("service_{}", random_num % 10),
303            scope_name: SCOPE.to_string(),
304            scope_version: None,
305            trace_state: "running".to_string(),
306            start_time: created_at,
307            end_time: created_at + chrono::Duration::milliseconds(150),
308            duration_ms: 150,
309            status_code: 0,
310            span_count: 1,
311            status_message: "OK".to_string(),
312            root_span_id: span_id,
313            tags: vec![],
314            process_attributes: vec![],
315        }
316    }
317
318    fn random_span_record(
319        trace_id: &TraceId,
320        parent_span_id: Option<&SpanId>,
321        service_name: &str,
322        minutes_offset: i64,
323        uid: Option<String>,
324    ) -> TraceSpanRecord {
325        let mut rng = rand::rng();
326        let span_id = SpanId::from_bytes(rng.random::<[u8; 8]>());
327
328        let random_offset_ms = rng.random_range(0..1000);
329        let duration_ms_val = rng.random_range(50..500);
330
331        let created_at = Utc::now() - Duration::minutes(minutes_offset)
332            + chrono::Duration::milliseconds(random_offset_ms);
333        let start_time = created_at;
334        let end_time = start_time + chrono::Duration::milliseconds(duration_ms_val);
335
336        let status_code = if rng.random_bool(0.95) { 0 } else { 2 };
337        let span_kind_options = ["SERVER", "CLIENT", "INTERNAL", "PRODUCER", "CONSUMER"];
338        let span_kind = span_kind_options[rng.random_range(0..span_kind_options.len())].to_string();
339        let mut attributes = vec![];
340
341        // randomly add SCOUTER_ENTITY to attributes based on 30% chance
342        attributes.push(Attribute {
343            key: "random_attribute".to_string(),
344            value: Value::String(format!("value_{}", rng.random_range(0..100))),
345        });
346
347        attributes.push(Attribute {
348            key: SCOUTER_ENTITY.to_string(),
349            value: Value::String(uid.unwrap_or(create_uuid7())),
350        });
351
352        if rng.random_bool(0.1) {
353            attributes.push(Attribute {
354                key: "component".to_string(),
355                value: Value::String("kafka".to_string()),
356            });
357        }
358
359        TraceSpanRecord {
360            created_at,
361            span_id,
362            trace_id: trace_id.clone(),
363            parent_span_id: parent_span_id.cloned(),
364            flags: 1,
365            trace_state: String::new(),
366            service_name: service_name.to_string(),
367            scope_name: SCOPE.to_string(),
368            scope_version: None,
369            span_name: format!("random_operation_{}", rng.random_range(0..10)),
370            span_kind,
371            start_time,
372            end_time,
373            duration_ms: duration_ms_val,
374            status_code,
375            status_message: if status_code == 2 {
376                "Internal Server Error".to_string()
377            } else {
378                "OK".to_string()
379            },
380            attributes,
381            events: vec![],
382            links: vec![],
383            label: None,
384            input: Value::default(),
385            output: Value::default(),
386            resource_attributes: vec![],
387        }
388    }
389
390    fn generate_trace_with_spans(
391        num_spans: usize,
392        minutes_offset: i64,
393    ) -> (TraceRecord, Vec<TraceSpanRecord>) {
394        let trace_record = random_trace_record();
395        let mut spans: Vec<TraceSpanRecord> = Vec::new();
396        let mut rng = rand::rng();
397
398        for i in 0..num_spans {
399            let parent_span_id = if i == 0 {
400                None
401            } else {
402                Some(&spans[rng.random_range(0..spans.len())].span_id)
403            };
404            let span_record = random_span_record(
405                &trace_record.trace_id,
406                parent_span_id,
407                &trace_record.service_name,
408                minutes_offset,
409                None,
410            );
411            spans.push(span_record);
412        }
413
414        (trace_record, spans)
415    }
416
417    pub async fn cleanup(pool: &Pool<Postgres>) {
418        sqlx::raw_sql(
419            r#"
420
421            DELETE
422            FROM scouter.drift_entities;
423
424            DELETE
425            FROM scouter.spc_drift;
426
427            DELETE
428            FROM scouter.observability_metric;
429
430            DELETE
431            FROM scouter.custom_drift;
432
433            DELETE
434            FROM scouter.drift_alert;
435
436            DELETE
437            FROM scouter.drift_profile;
438
439            DELETE
440            FROM scouter.psi_drift;
441
442            DELETE
443            FROM scouter.user;
444
445            DELETE
446            FROM scouter.genai_eval_record;
447
448            DELETE
449            FROM scouter.genai_eval_task;
450
451            DELETE
452            FROM scouter.genai_eval_workflow;
453
454            DELETE
455            FROM scouter.spans;
456
457            DELETE
458            FROM scouter.traces;
459
460            DELETE
461            FROM scouter.trace_entities;
462
463            DELETE
464            FROM scouter.trace_baggage;
465
466            DELETE
467            FROM scouter.tags;
468            "#,
469        )
470        .fetch_all(pool)
471        .await
472        .unwrap();
473    }
474
475    pub async fn db_pool() -> Pool<Postgres> {
476        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
477            .await
478            .unwrap();
479        cleanup(&pool).await;
480        pool
481    }
482
483    pub async fn insert_profile_to_db(
484        pool: &Pool<Postgres>,
485        profile: &DriftProfile,
486        active: bool,
487        deactivate_others: bool,
488    ) -> String {
489        let base_args = profile.get_base_args();
490        let version_request = VersionRequest {
491            version: None,
492            version_type: VersionType::Minor,
493            pre_tag: None,
494            build_tag: None,
495        };
496        let version = PostgresClient::get_next_profile_version(pool, &base_args, version_request)
497            .await
498            .unwrap();
499
500        let result = PostgresClient::insert_drift_profile(
501            pool,
502            profile,
503            &base_args,
504            &version,
505            &active,
506            &deactivate_others,
507        )
508        .await
509        .unwrap();
510
511        result
512    }
513
514    #[tokio::test]
515    async fn test_postgres_start() {
516        let _pool = db_pool().await;
517    }
518
519    #[tokio::test]
520    async fn test_postgres_drift_alert() {
521        let pool = db_pool().await;
522        let entity_id = 9999;
523
524        // Insert 10 alerts with slight delays to ensure different timestamps
525        for i in 0..10 {
526            let alert = AlertMap::Custom(custom::ComparisonMetricAlert {
527                metric_name: "test".to_string(),
528                baseline_value: 0 as f64,
529                observed_value: i as f64,
530                delta: None,
531                alert_threshold: AlertThreshold::Above,
532            });
533
534            let result = PostgresClient::insert_drift_alert(&pool, &entity_id, &alert)
535                .await
536                .unwrap();
537
538            assert_eq!(result.rows_affected(), 1);
539
540            // Small delay to ensure timestamp ordering
541            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
542        }
543
544        // Test 1: Get first page with default limit (50)
545        let request = DriftAlertPaginationRequest {
546            uid: create_uuid7(),
547            active: Some(true),
548            limit: Some(50),
549            ..Default::default()
550        };
551
552        let response = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
553            .await
554            .unwrap();
555
556        assert_eq!(response.items.len(), 10);
557        assert!(!response.has_next); // No more items
558        assert!(!response.has_previous); // First page
559        assert!(response.next_cursor.is_none());
560        assert!(response.previous_cursor.is_some());
561
562        // Test 2: Paginate with limit of 3 - forward direction
563        let request = DriftAlertPaginationRequest {
564            uid: create_uuid7(),
565            active: Some(true),
566            limit: Some(3),
567            direction: Some("next".to_string()),
568            ..Default::default()
569        };
570
571        let page1 = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
572            .await
573            .unwrap();
574
575        assert_eq!(page1.items.len(), 3);
576        assert!(page1.has_next);
577        assert!(!page1.has_previous);
578        assert!(page1.next_cursor.is_some());
579        assert!(page1.previous_cursor.is_some());
580
581        // Test 3: Get second page using next cursor
582        let next_cursor = page1.next_cursor.unwrap();
583        let request = DriftAlertPaginationRequest {
584            uid: create_uuid7(),
585            active: Some(true),
586            limit: Some(3),
587            cursor_created_at: Some(next_cursor.created_at),
588            cursor_id: Some(next_cursor.id as i32),
589            direction: Some("next".to_string()),
590            start_datetime: None,
591            end_datetime: None,
592        };
593
594        let page2 = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
595            .await
596            .unwrap();
597
598        assert_eq!(page2.items.len(), 3);
599        assert!(page2.has_next); // More items available
600        assert!(page2.has_previous); // Can go back
601        assert!(page2.next_cursor.is_some());
602        assert!(page2.previous_cursor.is_some());
603
604        // Verify no overlap between pages
605        let page1_ids: std::collections::HashSet<_> = page1.items.iter().map(|a| a.id).collect();
606        let page2_ids: std::collections::HashSet<_> = page2.items.iter().map(|a| a.id).collect();
607        assert!(page1_ids.is_disjoint(&page2_ids));
608
609        // Test 4: Navigate backward using previous cursor
610        let prev_cursor = page2.previous_cursor.unwrap();
611        let request = DriftAlertPaginationRequest {
612            uid: create_uuid7(),
613            active: Some(true),
614            limit: Some(3),
615            cursor_created_at: Some(prev_cursor.created_at),
616            cursor_id: Some(prev_cursor.id as i32),
617            direction: Some("previous".to_string()),
618            start_datetime: None,
619            end_datetime: None,
620        };
621
622        let page_back = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
623            .await
624            .unwrap();
625
626        assert_eq!(page_back.items.len(), 3);
627        assert!(page_back.has_next); // Can go forward
628                                     // Should return to page1 items
629        assert_eq!(
630            page_back.items.iter().map(|a| a.id).collect::<Vec<_>>(),
631            page1.items.iter().map(|a| a.id).collect::<Vec<_>>()
632        );
633
634        // Test 5: Filter by active status
635        // Deactivate some alerts first
636        let to_deactivate = &page1.items[0];
637        let update_request = UpdateAlertStatus {
638            id: to_deactivate.id,
639            active: false,
640            space: "test_space".to_string(),
641        };
642
643        PostgresClient::update_drift_alert_status(&pool, &update_request)
644            .await
645            .unwrap();
646
647        // Query only active alerts
648        let request = DriftAlertPaginationRequest {
649            uid: create_uuid7(),
650            active: Some(true),
651            limit: Some(50),
652            ..Default::default()
653        };
654
655        let active_alerts = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
656            .await
657            .unwrap();
658
659        assert_eq!(active_alerts.items.len(), 9); // 10 - 1 deactivated
660        assert!(active_alerts.items.iter().all(|a| a.active));
661
662        // Query only inactive alerts
663        let request = DriftAlertPaginationRequest {
664            uid: create_uuid7(),
665            active: Some(false),
666            limit: Some(50),
667            ..Default::default()
668        };
669
670        let inactive_alerts =
671            PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
672                .await
673                .unwrap();
674
675        assert_eq!(inactive_alerts.items.len(), 1);
676        assert!(!inactive_alerts.items[0].active);
677
678        // Test 6: Query all alerts (active and inactive)
679        let request = DriftAlertPaginationRequest {
680            uid: create_uuid7(),
681            active: None, // No filter
682            limit: Some(50),
683            ..Default::default()
684        };
685
686        let all_alerts = PostgresClient::get_paginated_drift_alerts(&pool, &request, &entity_id)
687            .await
688            .unwrap();
689
690        assert_eq!(all_alerts.items.len(), 10);
691
692        // Test 7: Empty result set
693        let request = DriftAlertPaginationRequest {
694            uid: create_uuid7(),
695            active: Some(true),
696            limit: Some(3),
697            ..Default::default()
698        };
699
700        let non_existent = PostgresClient::get_paginated_drift_alerts(&pool, &request, &999999)
701            .await
702            .unwrap();
703
704        assert_eq!(non_existent.items.len(), 0);
705        assert!(!non_existent.has_next);
706        assert!(!non_existent.has_previous);
707        assert!(non_existent.next_cursor.is_none());
708        assert!(non_existent.previous_cursor.is_none());
709    }
710
711    #[tokio::test]
712    async fn test_postgres_spc_drift_record() {
713        let pool = db_pool().await;
714
715        let record1 = SpcRecord {
716            created_at: Utc::now(),
717            uid: create_uuid7(),
718            feature: "test".to_string(),
719            value: 1.0,
720            entity_id: 0,
721        };
722
723        let record2 = SpcRecord {
724            created_at: Utc::now(),
725            uid: create_uuid7(),
726            feature: "test2".to_string(),
727            value: 2.0,
728            entity_id: 0,
729        };
730
731        let result =
732            PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2], &ENTITY_ID)
733                .await
734                .unwrap();
735
736        assert_eq!(result.rows_affected(), 2);
737    }
738
739    #[tokio::test]
740    async fn test_postgres_bin_count() {
741        let pool = db_pool().await;
742
743        let record1 = PsiRecord {
744            created_at: Utc::now(),
745            uid: create_uuid7(),
746            feature: "test".to_string(),
747            bin_id: 1,
748            bin_count: 1,
749            entity_id: ENTITY_ID,
750        };
751
752        let record2 = PsiRecord {
753            created_at: Utc::now(),
754            uid: create_uuid7(),
755            feature: "test2".to_string(),
756            bin_id: 2,
757            bin_count: 2,
758            entity_id: ENTITY_ID,
759        };
760
761        let result =
762            PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2], &ENTITY_ID)
763                .await
764                .unwrap();
765
766        assert_eq!(result.rows_affected(), 2);
767    }
768
769    #[tokio::test]
770    async fn test_postgres_observability_record() {
771        let pool = db_pool().await;
772
773        let record = ObservabilityMetrics::default();
774
775        let result = PostgresClient::insert_observability_record(&pool, &record, &ENTITY_ID)
776            .await
777            .unwrap();
778
779        assert_eq!(result.rows_affected(), 1);
780    }
781
782    #[tokio::test]
783    async fn test_postgres_crud_drift_profile() {
784        let pool = db_pool().await;
785
786        let mut spc_profile = SpcDriftProfile::default();
787        let profile = DriftProfile::Spc(spc_profile.clone());
788        let uid = insert_profile_to_db(&pool, &profile, false, false).await;
789        assert!(!uid.is_empty());
790
791        let entity_id = PostgresClient::get_entity_id_from_uid(&pool, &uid)
792            .await
793            .unwrap();
794
795        spc_profile.scouter_version = "test".to_string();
796
797        let result = PostgresClient::update_drift_profile(
798            &pool,
799            &DriftProfile::Spc(spc_profile.clone()),
800            &entity_id,
801        )
802        .await
803        .unwrap();
804
805        assert_eq!(result.rows_affected(), 1);
806
807        let profile = PostgresClient::get_drift_profile(&pool, &entity_id)
808            .await
809            .unwrap();
810
811        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();
812
813        assert_eq!(deserialized, spc_profile);
814
815        PostgresClient::update_drift_profile_status(
816            &pool,
817            &ProfileStatusRequest {
818                name: spc_profile.config.name.clone(),
819                space: spc_profile.config.space.clone(),
820                version: spc_profile.config.version.clone(),
821                active: false,
822                drift_type: Some(DriftType::Spc),
823                deactivate_others: false,
824            },
825        )
826        .await
827        .unwrap();
828    }
829
830    #[tokio::test]
831    async fn test_postgres_get_features() {
832        let pool = db_pool().await;
833
834        let timestamp = Utc::now();
835
836        for _ in 0..10 {
837            let mut records = Vec::new();
838            for j in 0..10 {
839                let record = SpcRecord {
840                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
841                    uid: create_uuid7(),
842                    feature: format!("test{j}"),
843                    value: j as f64,
844                    entity_id: ENTITY_ID,
845                };
846
847                records.push(record);
848            }
849
850            let result =
851                PostgresClient::insert_spc_drift_records_batch(&pool, &records, &ENTITY_ID)
852                    .await
853                    .unwrap();
854            assert_eq!(result.rows_affected(), records.len() as u64);
855        }
856
857        let features = PostgresClient::get_spc_features(&pool, &ENTITY_ID)
858            .await
859            .unwrap();
860        assert_eq!(features.len(), 10);
861
862        let records =
863            PostgresClient::get_spc_drift_records(&pool, &timestamp, &features, &ENTITY_ID)
864                .await
865                .unwrap();
866
867        assert_eq!(records.features.len(), 10);
868
869        let binned_records = PostgresClient::get_binned_spc_drift_records(
870            &pool,
871            &DriftRequest {
872                uid: create_uuid7(),
873                time_interval: TimeInterval::FifteenMinutes,
874                max_data_points: 10,
875                ..Default::default()
876            },
877            &DatabaseSettings::default().retention_period,
878            &ObjectStorageSettings::default(),
879            &ENTITY_ID,
880        )
881        .await
882        .unwrap();
883
884        assert_eq!(binned_records.features.len(), 10);
885    }
886
887    #[tokio::test]
888    async fn test_postgres_bin_proportions() {
889        let pool = db_pool().await;
890
891        let timestamp = Utc::now();
892
893        let num_features = 3;
894        let num_bins = 5;
895
896        let features = (0..=num_features)
897            .map(|feature| {
898                let bins = (0..=num_bins)
899                    .map(|bind_id| Bin {
900                        id: bind_id,
901                        lower_limit: None,
902                        upper_limit: None,
903                        proportion: 0.0,
904                    })
905                    .collect();
906                let feature_name = format!("feature{feature}");
907                let feature_profile = PsiFeatureDriftProfile {
908                    id: feature_name.clone(),
909                    bins,
910                    timestamp,
911                    bin_type: BinType::Numeric,
912                };
913                (feature_name, feature_profile)
914            })
915            .collect();
916
917        let profile = &DriftProfile::Psi(psi::PsiDriftProfile::new(
918            features,
919            PsiDriftConfig {
920                space: SPACE.to_string(),
921                name: NAME.to_string(),
922                version: VERSION.to_string(),
923                ..Default::default()
924            },
925        ));
926        let uid = insert_profile_to_db(&pool, profile, false, false).await;
927        let entity_id = PostgresClient::get_entity_id_from_uid(&pool, &uid)
928            .await
929            .unwrap();
930
931        for feature in 0..num_features {
932            for bin in 0..=num_bins {
933                let mut records = Vec::new();
934                for j in 0..=100 {
935                    let record = PsiRecord {
936                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
937                        uid: create_uuid7(),
938                        feature: format!("feature{feature}"),
939                        bin_id: bin,
940                        bin_count: rand::rng().random_range(0..10) as i32,
941                        entity_id: ENTITY_ID,
942                    };
943
944                    records.push(record);
945                }
946                PostgresClient::insert_bin_counts_batch(&pool, &records, &entity_id)
947                    .await
948                    .unwrap();
949            }
950        }
951
952        let binned_records = PostgresClient::get_feature_distributions(
953            &pool,
954            &timestamp,
955            &["feature0".to_string()],
956            &entity_id,
957        )
958        .await
959        .unwrap();
960
961        // assert binned_records.features["test"]["decile_1"] is around .5
962        let bin_proportion = binned_records
963            .distributions
964            .get("feature0")
965            .unwrap()
966            .bins
967            .get(&1)
968            .unwrap();
969
970        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);
971
972        let binned_records = PostgresClient::get_binned_psi_drift_records(
973            &pool,
974            &DriftRequest {
975                uid: create_uuid7(),
976                time_interval: TimeInterval::OneHour,
977                max_data_points: 1000,
978                ..Default::default()
979            },
980            &DatabaseSettings::default().retention_period,
981            &ObjectStorageSettings::default(),
982            &entity_id,
983        )
984        .await
985        .unwrap();
986        //
987        assert_eq!(binned_records.len(), 3);
988    }
989
990    #[tokio::test]
991    async fn test_postgres_cru_custom_metric() {
992        let pool = db_pool().await;
993        let timestamp = Utc::now();
994
995        let (uid, entity_id) = PostgresClient::create_entity(
996            &pool,
997            SPACE,
998            NAME,
999            VERSION,
1000            DriftType::Custom.to_string(),
1001        )
1002        .await
1003        .unwrap();
1004
1005        for i in 0..2 {
1006            let mut records = Vec::new();
1007            for j in 0..25 {
1008                let record = CustomMetricRecord {
1009                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1010                    uid: uid.clone(),
1011                    metric: format!("metric{i}"),
1012                    value: rand::rng().random_range(0..10) as f64,
1013                    entity_id: ENTITY_ID,
1014                };
1015                records.push(record);
1016            }
1017            let result =
1018                PostgresClient::insert_custom_metric_values_batch(&pool, &records, &entity_id)
1019                    .await
1020                    .unwrap();
1021            assert_eq!(result.rows_affected(), 25);
1022        }
1023
1024        // insert random record to test has statistics funcs handle single record
1025        let record = CustomMetricRecord {
1026            created_at: Utc::now(),
1027            uid: uid.clone(),
1028            metric: "metric3".to_string(),
1029            value: rand::rng().random_range(0..10) as f64,
1030            entity_id: ENTITY_ID,
1031        };
1032
1033        let result =
1034            PostgresClient::insert_custom_metric_values_batch(&pool, &[record], &entity_id)
1035                .await
1036                .unwrap();
1037        assert_eq!(result.rows_affected(), 1);
1038
1039        let metrics = PostgresClient::get_custom_metric_values(
1040            &pool,
1041            &timestamp,
1042            &["metric1".to_string()],
1043            &entity_id,
1044        )
1045        .await
1046        .unwrap();
1047
1048        assert_eq!(metrics.len(), 1);
1049
1050        let binned_records = PostgresClient::get_binned_custom_drift_records(
1051            &pool,
1052            &DriftRequest {
1053                uid: uid.clone(),
1054                time_interval: TimeInterval::OneHour,
1055                max_data_points: 1000,
1056                ..Default::default()
1057            },
1058            &DatabaseSettings::default().retention_period,
1059            &ObjectStorageSettings::default(),
1060            &entity_id,
1061        )
1062        .await
1063        .unwrap();
1064        //
1065        assert_eq!(binned_records.metrics.len(), 3);
1066    }
1067
1068    #[tokio::test]
1069    async fn test_postgres_user() {
1070        let pool = db_pool().await;
1071        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];
1072
1073        // Create
1074        let user = User::new(
1075            "user".to_string(),
1076            "pass".to_string(),
1077            "email".to_string(),
1078            recovery_codes,
1079            None,
1080            None,
1081            None,
1082            None,
1083        );
1084        PostgresClient::insert_user(&pool, &user).await.unwrap();
1085
1086        // Read
1087        let mut user = PostgresClient::get_user(&pool, "user")
1088            .await
1089            .unwrap()
1090            .unwrap();
1091
1092        assert_eq!(user.username, "user");
1093        assert_eq!(user.group_permissions, vec!["user"]);
1094        assert_eq!(user.email, "email");
1095
1096        // update user
1097        user.active = false;
1098        user.refresh_token = Some("token".to_string());
1099
1100        // Update
1101        PostgresClient::update_user(&pool, &user).await.unwrap();
1102        let user = PostgresClient::get_user(&pool, "user")
1103            .await
1104            .unwrap()
1105            .unwrap();
1106        assert!(!user.active);
1107        assert_eq!(user.refresh_token.unwrap(), "token");
1108
1109        // get users
1110        let users = PostgresClient::get_users(&pool).await.unwrap();
1111        assert_eq!(users.len(), 1);
1112
1113        // get last admin
1114        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
1115        assert!(!is_last_admin);
1116
1117        // delete
1118        PostgresClient::delete_user(&pool, "user").await.unwrap();
1119    }
1120
1121    #[tokio::test]
1122    async fn test_postgres_genai_eval_record_reschedule() {
1123        let pool = db_pool().await;
1124
1125        let (uid, entity_id) = PostgresClient::create_entity(
1126            &pool,
1127            SPACE,
1128            NAME,
1129            VERSION,
1130            DriftType::GenAI.to_string(),
1131        )
1132        .await
1133        .unwrap();
1134
1135        let input = "This is a test input";
1136        let output = "This is a test response";
1137
1138        for j in 0..10 {
1139            let context = serde_json::json!({
1140                "input": input,
1141                "response": output,
1142            });
1143            let record = GenAIEvalRecord {
1144                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1145                context,
1146                status: Status::Pending,
1147                id: 0, // This will be set by the database
1148                uid: format!("test_{}", j),
1149                entity_uid: uid.clone(),
1150                entity_id,
1151                ..Default::default()
1152            };
1153
1154            let boxed = BoxedGenAIEvalRecord::new(record);
1155
1156            let result = PostgresClient::insert_genai_eval_record(&pool, boxed, &entity_id)
1157                .await
1158                .unwrap();
1159
1160            assert_eq!(result.rows_affected(), 1);
1161        }
1162
1163        let features = PostgresClient::get_genai_eval_records(&pool, None, None, &entity_id)
1164            .await
1165            .unwrap();
1166        assert_eq!(features.len(), 10);
1167
1168        // get pending task
1169        let pending_tasks = PostgresClient::get_pending_genai_eval_record(&pool)
1170            .await
1171            .unwrap();
1172
1173        // assert not empty
1174        assert!(pending_tasks.is_some());
1175
1176        // get pending task with space, name, version
1177        let task_input = &pending_tasks.as_ref().unwrap().context["input"];
1178        assert_eq!(*task_input, "This is a test input".to_string());
1179
1180        // reschedule task
1181        PostgresClient::reschedule_genai_eval_record(
1182            &pool,
1183            &pending_tasks.as_ref().unwrap().uid,
1184            Duration::seconds(30),
1185        )
1186        .await
1187        .unwrap();
1188    }
1189
1190    #[tokio::test]
1191    async fn test_postgres_genai_eval_record_insert_get() {
1192        let pool = db_pool().await;
1193
1194        let (uid, entity_id) = PostgresClient::create_entity(
1195            &pool,
1196            SPACE,
1197            NAME,
1198            VERSION,
1199            DriftType::GenAI.to_string(),
1200        )
1201        .await
1202        .unwrap();
1203
1204        let input = "This is a test input";
1205        let output = "This is a test response";
1206
1207        for j in 0..10 {
1208            let context = serde_json::json!({
1209                "input": input,
1210                "response": output,
1211            });
1212            let record = GenAIEvalRecord {
1213                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1214                context,
1215                status: Status::Pending,
1216                id: 0, // This will be set by the database
1217                uid: format!("test_{}", j),
1218                entity_uid: uid.clone(),
1219                entity_id,
1220                ..Default::default()
1221            };
1222
1223            let boxed = BoxedGenAIEvalRecord::new(record);
1224
1225            let result = PostgresClient::insert_genai_eval_record(&pool, boxed, &entity_id)
1226                .await
1227                .unwrap();
1228
1229            assert_eq!(result.rows_affected(), 1);
1230        }
1231
1232        let features = PostgresClient::get_genai_eval_records(&pool, None, None, &entity_id)
1233            .await
1234            .unwrap();
1235        assert_eq!(features.len(), 10);
1236
1237        // get pending task
1238        let pending_tasks = PostgresClient::get_pending_genai_eval_record(&pool)
1239            .await
1240            .unwrap();
1241
1242        // assert not empty
1243        assert!(pending_tasks.is_some());
1244
1245        // get pending task with space, name, version
1246        let task_input = &pending_tasks.as_ref().unwrap().context["input"];
1247        assert_eq!(*task_input, "This is a test input".to_string());
1248
1249        // update pending task
1250        PostgresClient::update_genai_eval_record_status(
1251            &pool,
1252            &pending_tasks.unwrap(),
1253            Status::Processed,
1254            &(1_i64),
1255        )
1256        .await
1257        .unwrap();
1258
1259        // query processed tasks
1260        let processed_tasks = PostgresClient::get_genai_eval_records(
1261            &pool,
1262            None,
1263            Some(Status::Processed),
1264            &entity_id,
1265        )
1266        .await
1267        .unwrap();
1268
1269        // assert not empty
1270        assert_eq!(processed_tasks.len(), 1);
1271    }
1272
1273    #[tokio::test]
1274    async fn test_postgres_genai_eval_record_pagination() {
1275        let pool = db_pool().await;
1276
1277        let (uid, entity_id) = PostgresClient::create_entity(
1278            &pool,
1279            SPACE,
1280            NAME,
1281            VERSION,
1282            DriftType::GenAI.to_string(),
1283        )
1284        .await
1285        .unwrap();
1286
1287        let input = "This is a test input";
1288        let output = "This is a test response";
1289
1290        // Insert 10 records with increasing timestamps
1291        for j in 0..10 {
1292            let context = serde_json::json!({
1293                "input": input,
1294                "response": output,
1295            });
1296            let record = GenAIEvalRecord {
1297                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1298                context,
1299                status: Status::Pending,
1300                id: 0, // This will be set by the database
1301                uid: format!("test_{}", j),
1302                entity_uid: uid.clone(),
1303                entity_id: ENTITY_ID,
1304                ..Default::default()
1305            };
1306
1307            let boxed = BoxedGenAIEvalRecord::new(record);
1308
1309            let result = PostgresClient::insert_genai_eval_record(&pool, boxed, &entity_id)
1310                .await
1311                .unwrap();
1312
1313            assert_eq!(result.rows_affected(), 1);
1314        }
1315
1316        // ===== PAGE 1: Get first 5 records (newest) =====
1317        let params = GenAIEvalRecordPaginationRequest {
1318            status: None,
1319            limit: Some(5),
1320            cursor_created_at: None,
1321            cursor_id: None,
1322            direction: None,
1323            ..Default::default()
1324        };
1325
1326        let page1 = PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1327            .await
1328            .unwrap();
1329
1330        assert_eq!(page1.items.len(), 5, "Page 1 should have 5 records");
1331        assert!(page1.has_next, "Should have next page");
1332        assert!(
1333            !page1.has_previous,
1334            "Should not have previous page (first page)"
1335        );
1336        assert!(page1.next_cursor.is_some(), "Should have next cursor");
1337
1338        // First item should be the NEWEST record (highest ID)
1339        let page1_first = page1.items.first().unwrap();
1340        let page1_last = page1.items.last().unwrap();
1341
1342        assert!(
1343            page1_first.created_at >= page1_last.created_at,
1344            "Page 1 should be sorted newest first (DESC)"
1345        );
1346
1347        // ===== PAGE 2: Get next 5 records (older) =====
1348        let next_cursor = page1.next_cursor.unwrap();
1349
1350        let params = GenAIEvalRecordPaginationRequest {
1351            status: None,
1352            limit: Some(5),
1353            cursor_created_at: Some(next_cursor.created_at),
1354            cursor_id: Some(next_cursor.id),
1355            direction: None,
1356            ..Default::default()
1357        };
1358
1359        let page2 = PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1360            .await
1361            .unwrap();
1362
1363        assert_eq!(page2.items.len(), 5, "Page 2 should have 5 records");
1364        assert!(!page2.has_next, "Should not have next page (last page)");
1365        assert!(page2.has_previous, "Should have previous page");
1366        assert!(
1367            page2.previous_cursor.is_some(),
1368            "Should have previous cursor"
1369        );
1370
1371        let page2_first = page2.items.first().unwrap();
1372
1373        // Page 2 first item should be OLDER than Page 1 last item
1374        assert!(
1375            page2_first.created_at < page1_last.created_at
1376                || (page2_first.created_at == page1_last.created_at
1377                    && page2_first.id < page1_last.id),
1378            "Page 2 should start with records older than Page 1 last item"
1379        );
1380
1381        // Verify we got all 10 records across both pages
1382        let all_ids: Vec<i64> = page1
1383            .items
1384            .iter()
1385            .chain(page2.items.iter())
1386            .map(|r| r.id)
1387            .collect();
1388
1389        assert_eq!(all_ids.len(), 10, "Should have 10 unique records total");
1390
1391        // All IDs should be unique
1392        let unique_ids: std::collections::HashSet<_> = all_ids.iter().collect();
1393        assert_eq!(unique_ids.len(), 10, "All IDs should be unique");
1394
1395        // ===== BACKWARD PAGINATION TEST =====
1396        // Go back from page 2 to page 1
1397        let previous_cursor = page2.previous_cursor.unwrap();
1398
1399        let params = GenAIEvalRecordPaginationRequest {
1400            status: None,
1401            limit: Some(5),
1402            cursor_created_at: Some(previous_cursor.created_at),
1403            cursor_id: Some(previous_cursor.id),
1404            direction: Some("previous".to_string()),
1405            ..Default::default()
1406        };
1407
1408        let page1_again =
1409            PostgresClient::get_paginated_genai_eval_records(&pool, &params, &entity_id)
1410                .await
1411                .unwrap();
1412
1413        assert_eq!(
1414            page1_again.items.len(),
1415            5,
1416            "Going back should return 5 records"
1417        );
1418
1419        // Should get the same records as page 1
1420        assert_eq!(
1421            page1_again.items.first().unwrap().id,
1422            page1_first.id,
1423            "Should return to the same first record"
1424        );
1425    }
1426
1427    #[tokio::test]
1428    async fn test_postgres_genai_eval_workflow_pagination() {
1429        let pool = db_pool().await;
1430
1431        let (_uid, entity_id) = PostgresClient::create_entity(
1432            &pool,
1433            SPACE,
1434            NAME,
1435            VERSION,
1436            DriftType::GenAI.to_string(),
1437        )
1438        .await
1439        .unwrap();
1440
1441        // Insert 10 records with increasing timestamps
1442        for j in 0..10 {
1443            let record = GenAIEvalWorkflowResult {
1444                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1445                record_uid: format!("test_{}", j),
1446                entity_id,
1447                ..Default::default()
1448            };
1449
1450            let result =
1451                PostgresClient::insert_genai_eval_workflow_record(&pool, &record, &entity_id)
1452                    .await
1453                    .unwrap();
1454
1455            assert_eq!(result.rows_affected(), 1);
1456        }
1457
1458        // ===== PAGE 1: Get first 5 records (newest) =====
1459        let params = GenAIEvalRecordPaginationRequest {
1460            status: None,
1461            limit: Some(5),
1462            cursor_created_at: None,
1463            cursor_id: None,
1464            direction: None,
1465            ..Default::default()
1466        };
1467
1468        let page1 =
1469            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1470                .await
1471                .unwrap();
1472
1473        assert_eq!(page1.items.len(), 5, "Page 1 should have 5 records");
1474        assert!(page1.has_next, "Should have next page");
1475        assert!(
1476            !page1.has_previous,
1477            "Should not have previous page (first page)"
1478        );
1479        assert!(page1.next_cursor.is_some(), "Should have next cursor");
1480
1481        // First item should be the NEWEST record (highest ID)
1482        let page1_first = page1.items.first().unwrap();
1483        let page1_last = page1.items.last().unwrap();
1484
1485        assert!(
1486            page1_first.created_at >= page1_last.created_at,
1487            "Page 1 should be sorted newest first (DESC)"
1488        );
1489
1490        // ===== PAGE 2: Get next 5 records (older) =====
1491        let next_cursor = page1.next_cursor.unwrap();
1492
1493        let params = GenAIEvalRecordPaginationRequest {
1494            status: None,
1495            limit: Some(5),
1496            cursor_created_at: Some(next_cursor.created_at),
1497            cursor_id: Some(next_cursor.id),
1498            direction: None,
1499            ..Default::default()
1500        };
1501
1502        let page2 =
1503            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1504                .await
1505                .unwrap();
1506
1507        assert_eq!(page2.items.len(), 5, "Page 2 should have 5 records");
1508        assert!(!page2.has_next, "Should not have next page (last page)");
1509        assert!(page2.has_previous, "Should have previous page");
1510        assert!(
1511            page2.previous_cursor.is_some(),
1512            "Should have previous cursor"
1513        );
1514
1515        let page2_first = page2.items.first().unwrap();
1516
1517        // Page 2 first item should be OLDER than Page 1 last item
1518        assert!(
1519            page2_first.created_at < page1_last.created_at
1520                || (page2_first.created_at == page1_last.created_at
1521                    && page2_first.id < page1_last.id),
1522            "Page 2 should start with records older than Page 1 last item"
1523        );
1524
1525        // Verify we got all 10 records across both pages
1526        let all_ids: Vec<i64> = page1
1527            .items
1528            .iter()
1529            .chain(page2.items.iter())
1530            .map(|r| r.id)
1531            .collect();
1532
1533        assert_eq!(all_ids.len(), 10, "Should have 10 unique records total");
1534
1535        // All IDs should be unique
1536        let unique_ids: std::collections::HashSet<_> = all_ids.iter().collect();
1537        assert_eq!(unique_ids.len(), 10, "All IDs should be unique");
1538
1539        // ===== BACKWARD PAGINATION TEST =====
1540        // Go back from page 2 to page 1
1541        let previous_cursor = page2.previous_cursor.unwrap();
1542
1543        let params = GenAIEvalRecordPaginationRequest {
1544            status: None,
1545            limit: Some(5),
1546            cursor_created_at: Some(previous_cursor.created_at),
1547            cursor_id: Some(previous_cursor.id),
1548            direction: Some("previous".to_string()),
1549            ..Default::default()
1550        };
1551
1552        let page1_again =
1553            PostgresClient::get_paginated_genai_eval_workflow_records(&pool, &params, &entity_id)
1554                .await
1555                .unwrap();
1556
1557        assert_eq!(
1558            page1_again.items.len(),
1559            5,
1560            "Going back should return 5 records"
1561        );
1562
1563        // Should get the same records as page 1
1564        assert_eq!(
1565            page1_again.items.first().unwrap().id,
1566            page1_first.id,
1567            "Should return to the same first record"
1568        );
1569    }
1570
1571    #[tokio::test]
1572    async fn test_postgres_genai_task_result_insert_get() {
1573        let pool = db_pool().await;
1574
1575        let timestamp = Utc::now();
1576
1577        let (uid, entity_id) = PostgresClient::create_entity(
1578            &pool,
1579            SPACE,
1580            NAME,
1581            VERSION,
1582            DriftType::GenAI.to_string(),
1583        )
1584        .await
1585        .unwrap();
1586
1587        let mut records = Vec::new();
1588        for i in 0..2 {
1589            for j in 0..25 {
1590                let record = GenAIEvalTaskResult {
1591                    record_uid: format!("record_uid_{i}_{j}"),
1592                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
1593                    start_time: Utc::now(),
1594                    end_time: Utc::now() + chrono::Duration::seconds(1),
1595                    entity_id,
1596                    task_id: format!("task{i}"),
1597                    task_type: scouter_types::genai::EvaluationTaskType::Assertion,
1598                    passed: true,
1599                    value: j as f64,
1600                    assertion: Assertion::FieldPath(Some(format!("field.path.{i}"))),
1601                    operator: scouter_types::genai::ComparisonOperator::Contains,
1602                    expected: Value::Null,
1603                    actual: Value::Null,
1604                    message: "All good".to_string(),
1605                    entity_uid: uid.clone(),
1606                    condition: false,
1607                    stage: 0_i32,
1608                };
1609                records.push(record);
1610            }
1611            let result =
1612                PostgresClient::insert_eval_task_results_batch(&pool, &records, &entity_id)
1613                    .await
1614                    .unwrap();
1615            assert_eq!(result.rows_affected(), 25);
1616        }
1617
1618        let metrics = PostgresClient::get_genai_task_values(
1619            &pool,
1620            &timestamp,
1621            &["task1".to_string()],
1622            &entity_id,
1623        )
1624        .await
1625        .unwrap();
1626
1627        assert_eq!(metrics.len(), 1);
1628        let binned_records = PostgresClient::get_binned_genai_task_values(
1629            &pool,
1630            &DriftRequest {
1631                uid: uid.clone(),
1632                time_interval: TimeInterval::OneHour,
1633                max_data_points: 1000,
1634                ..Default::default()
1635            },
1636            &DatabaseSettings::default().retention_period,
1637            &ObjectStorageSettings::default(),
1638            &entity_id,
1639        )
1640        .await
1641        .unwrap();
1642        //
1643        assert_eq!(binned_records.metrics.len(), 2);
1644
1645        let eval_task = PostgresClient::get_genai_eval_task(&pool, &records[0].record_uid)
1646            .await
1647            .unwrap();
1648
1649        assert_eq!(eval_task[0].record_uid, records[0].record_uid);
1650    }
1651
1652    #[tokio::test]
1653    async fn test_postgres_genai_workflow_result_insert_get() {
1654        let pool = db_pool().await;
1655
1656        let timestamp = Utc::now();
1657
1658        let (uid, entity_id) = PostgresClient::create_entity(
1659            &pool,
1660            SPACE,
1661            NAME,
1662            VERSION,
1663            DriftType::GenAI.to_string(),
1664        )
1665        .await
1666        .unwrap();
1667
1668        for i in 0..2 {
1669            for j in 0..25 {
1670                let record = GenAIEvalWorkflowResult {
1671                    record_uid: format!("record_uid_{i}_{j}"),
1672                    created_at: Utc::now() + chrono::Duration::hours(i),
1673                    entity_id,
1674                    total_tasks: 10,
1675                    passed_tasks: 8,
1676                    failed_tasks: 2,
1677                    pass_rate: 0.8,
1678                    duration_ms: 1500,
1679                    entity_uid: uid.clone(),
1680                    execution_plan: ExecutionPlan::default(),
1681                    id: 0,
1682                };
1683                let result =
1684                    PostgresClient::insert_genai_eval_workflow_record(&pool, &record, &entity_id)
1685                        .await
1686                        .unwrap();
1687                assert_eq!(result.rows_affected(), 1);
1688            }
1689        }
1690
1691        let metric = PostgresClient::get_genai_workflow_value(&pool, &timestamp, &entity_id)
1692            .await
1693            .unwrap();
1694
1695        assert!(metric.is_some());
1696
1697        let binned_records = PostgresClient::get_binned_genai_workflow_values(
1698            &pool,
1699            &DriftRequest {
1700                uid: uid.clone(),
1701                time_interval: TimeInterval::OneHour,
1702                max_data_points: 1000,
1703                ..Default::default()
1704            },
1705            &DatabaseSettings::default().retention_period,
1706            &ObjectStorageSettings::default(),
1707            &entity_id,
1708        )
1709        .await
1710        .unwrap();
1711        //
1712        assert_eq!(binned_records.metrics.len(), 1);
1713    }
1714
1715    #[tokio::test]
1716    async fn test_postgres_tracing_metrics() {
1717        init_tracing();
1718        let pool = db_pool().await;
1719
1720        // Insert 1000 trace records with random data
1721
1722        for minute in 0..100 {
1723            let (_trace_record, spans) = generate_trace_with_spans(20, minute);
1724            let _ = PostgresClient::insert_span_batch(&pool, &spans)
1725                .await
1726                .unwrap();
1727        }
1728
1729        // Wait for background flush and then force flush any remaining
1730        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
1731        let flushed = shutdown_trace_cache(&pool).await.unwrap();
1732        info!("Flushed {} traces during shutdown", flushed);
1733
1734        let mut filters = TraceFilters::default();
1735        let first_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1736            .await
1737            .unwrap();
1738
1739        assert_eq!(
1740            first_batch.items.len(),
1741            50,
1742            "First batch should have 50 records"
1743        );
1744
1745        // test pagination (get last record created_at and trace_id)
1746        let last_record = first_batch.next_cursor.unwrap();
1747        filters = filters.next_page(&last_record);
1748
1749        let next_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1750            .await
1751            .unwrap();
1752
1753        // should be another 50 records
1754        assert_eq!(
1755            next_batch.items.len(),
1756            50,
1757            "Next batch should have 50 records"
1758        );
1759
1760        // assert next_batch first record timestamp is <= first_batch last record timestamp
1761        let next_first_record = next_batch.items.first().unwrap();
1762        assert!(
1763            next_first_record.start_time <= last_record.start_time,
1764            "Next batch first record timestamp is not less than or equal to last record timestamp"
1765        );
1766
1767        // test pagination for previous
1768        filters = filters.previous_page(&next_batch.previous_cursor.unwrap());
1769        let previous_batch = PostgresClient::get_paginated_traces(&pool, filters.clone())
1770            .await
1771            .unwrap();
1772        assert_eq!(
1773            previous_batch.items.len(),
1774            50,
1775            "Previous batch should have 50 records"
1776        );
1777
1778        // Filter records to find item with >5 spans
1779        let filtered_record = first_batch
1780            .items
1781            .iter()
1782            .find(|record| record.span_count > 5)
1783            .unwrap();
1784
1785        filters.cursor_start_time = None;
1786        filters.cursor_trace_id = None;
1787
1788        let records = PostgresClient::get_paginated_traces(&pool, filters.clone())
1789            .await
1790            .unwrap();
1791
1792        // Records are randomly generated, so just assert we get some records back
1793        assert!(
1794            !records.items.is_empty(),
1795            "Should return records with specified filters"
1796        );
1797
1798        // get spans for filtered trace
1799        let spans = PostgresClient::get_trace_spans(&pool, &filtered_record.trace_id, None)
1800            .await
1801            .unwrap();
1802
1803        assert!(spans.len() == filtered_record.span_count as usize);
1804
1805        let start_time = filtered_record.start_time - chrono::Duration::hours(48);
1806        let end_time = filtered_record.start_time + chrono::Duration::minutes(5);
1807
1808        // make request for trace metrics
1809        let trace_metrics = PostgresClient::get_trace_metrics(
1810            &pool,
1811            None,
1812            start_time,
1813            end_time,
1814            "5 minutes",
1815            None,
1816            None,
1817        )
1818        .await
1819        .unwrap();
1820
1821        // assert we have data points (all traces fall within ~1 second, so 1-2 buckets expected)
1822        assert!(
1823            !trace_metrics.is_empty(),
1824            "Should have at least one bucket of trace metrics"
1825        );
1826
1827        // get paginated traces with tags
1828        let filters = scouter_types::sql::TraceFilters {
1829            attribute_filters: Some(vec![("component=kafka".to_string())]),
1830            ..Default::default()
1831        };
1832        let tagged_batch = PostgresClient::get_paginated_traces(&pool, filters)
1833            .await
1834            .unwrap();
1835        assert!(!tagged_batch.items.is_empty());
1836    }
1837
1838    #[tokio::test]
1839    async fn test_postgres_tracing_insert() {
1840        let pool = db_pool().await;
1841
1842        // create parent trace
1843        let mut trace_record = random_trace_record();
1844        let trace_id = trace_record.trace_id.clone();
1845        let uid = create_uuid7();
1846
1847        // create spans
1848        let root_span = random_span_record(
1849            &trace_id,
1850            None,
1851            &trace_record.service_name,
1852            0_i64,
1853            Some(uid.clone()),
1854        );
1855        let child_span = random_span_record(
1856            &trace_id,
1857            Some(&root_span.span_id),
1858            &root_span.service_name,
1859            0_i64,
1860            None,
1861        );
1862
1863        // set root span id in trace record
1864        trace_record.root_span_id = root_span.span_id.clone();
1865
1866        // insert spans
1867        let result =
1868            PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
1869                .await
1870                .unwrap();
1871
1872        // wait for 1 second to ensure all records are flushed from cache to db
1873        shutdown_trace_cache(&pool).await.unwrap();
1874        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
1875
1876        assert_eq!(result.rows_affected(), 2);
1877
1878        let inserted_created_at = trace_record.created_at;
1879        let inserted_trace_id = trace_record.trace_id.clone();
1880
1881        let trace_filter = TraceFilters {
1882            cursor_start_time: Some(inserted_created_at + Duration::days(1)),
1883            cursor_trace_id: Some(inserted_trace_id.to_hex()),
1884            start_time: Some(inserted_created_at - Duration::minutes(5)),
1885            end_time: Some(inserted_created_at + Duration::days(1)),
1886            ..TraceFilters::default()
1887        };
1888
1889        let traces = PostgresClient::get_paginated_traces(&pool, trace_filter)
1890            .await
1891            .unwrap();
1892
1893        assert_eq!(traces.items.len(), 1);
1894        let retrieved_trace = &traces.items[0];
1895        // assert span count is 2
1896        assert_eq!(retrieved_trace.span_count, 2);
1897
1898        let baggage = TraceBaggageRecord {
1899            created_at: Utc::now(),
1900            trace_id: trace_record.trace_id.clone(),
1901            scope: "test_scope".to_string(),
1902            key: "user_id".to_string(),
1903            value: "12345".to_string(),
1904        };
1905
1906        let result =
1907            PostgresClient::insert_trace_baggage_batch(&pool, std::slice::from_ref(&baggage))
1908                .await
1909                .unwrap();
1910
1911        assert_eq!(result.rows_affected(), 1);
1912
1913        let trace_id = trace_record.trace_id.to_hex();
1914        let retrieved_baggage = PostgresClient::get_trace_baggage_records(&pool, &trace_id)
1915            .await
1916            .unwrap();
1917
1918        assert_eq!(retrieved_baggage.len(), 1);
1919
1920        // query by entity_uid
1921        let trace_filter = TraceFilters {
1922            start_time: Some(inserted_created_at - Duration::days(1)),
1923            end_time: Some(inserted_created_at + Duration::days(1)),
1924            entity_uid: Some(uid),
1925            ..TraceFilters::default()
1926        };
1927
1928        let traces = PostgresClient::get_paginated_traces(&pool, trace_filter)
1929            .await
1930            .unwrap();
1931
1932        assert!(!traces.items.is_empty());
1933    }
1934
1935    #[tokio::test]
1936    async fn test_postgres_tags() {
1937        let pool = db_pool().await;
1938        let uid = create_uuid7();
1939
1940        let tag1 = TagRecord {
1941            entity_id: uid.clone(),
1942            entity_type: "service".to_string(),
1943            key: "env".to_string(),
1944            value: "production".to_string(),
1945        };
1946
1947        let tag2 = TagRecord {
1948            entity_id: uid.clone(),
1949            entity_type: "service".to_string(),
1950            key: "team".to_string(),
1951            value: "backend".to_string(),
1952        };
1953
1954        let result = PostgresClient::insert_tag_batch(&pool, &[tag1.clone(), tag2.clone()])
1955            .await
1956            .unwrap();
1957
1958        assert_eq!(result.rows_affected(), 2);
1959
1960        let tags = PostgresClient::get_tags(&pool, "service", &uid)
1961            .await
1962            .unwrap();
1963
1964        assert_eq!(tags.len(), 2);
1965
1966        let tag_filter = vec![Tag {
1967            key: tags.first().unwrap().key.clone(),
1968            value: tags.first().unwrap().value.clone(),
1969        }];
1970
1971        let entity_id = PostgresClient::get_entity_id_by_tags(&pool, "service", &tag_filter, false)
1972            .await
1973            .unwrap();
1974
1975        assert_eq!(entity_id.first().unwrap(), &uid);
1976    }
1977
1978    #[tokio::test]
1979    async fn test_postgres_get_spans_from_tags() {
1980        let pool = db_pool().await;
1981
1982        // create parent trace
1983        let mut trace_record = random_trace_record();
1984        let trace_id = trace_record.trace_id.clone();
1985
1986        // create spans
1987        let root_span =
1988            random_span_record(&trace_id, None, &trace_record.service_name, 0_i64, None);
1989        let child_span = random_span_record(
1990            &trace_id,
1991            Some(&root_span.span_id),
1992            &root_span.service_name,
1993            0_i64,
1994            None,
1995        );
1996
1997        // set root span id in trace record
1998        trace_record.root_span_id = root_span.span_id.clone();
1999
2000        // insert spans
2001        let result =
2002            PostgresClient::insert_span_batch(&pool, &[root_span.clone(), child_span.clone()])
2003                .await
2004                .unwrap();
2005
2006        assert_eq!(result.rows_affected(), 2);
2007
2008        shutdown_trace_cache(&pool).await.unwrap();
2009        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
2010
2011        let tag = TagRecord {
2012            entity_id: trace_record.trace_id.to_hex(),
2013            entity_type: "trace".to_string(),
2014            key: "env".to_string(),
2015            value: "production".to_string(),
2016        };
2017
2018        let result = PostgresClient::insert_tag_batch(&pool, std::slice::from_ref(&tag))
2019            .await
2020            .unwrap();
2021
2022        assert_eq!(result.rows_affected(), 1);
2023
2024        let tag_filters = vec![HashMap::from([
2025            ("key".to_string(), "env".to_string()),
2026            ("value".to_string(), "production".to_string()),
2027        ])];
2028
2029        let spans = PostgresClient::get_spans_from_tags(&pool, "trace", tag_filters, true, None)
2030            .await
2031            .unwrap();
2032
2033        assert_eq!(spans.len(), 2);
2034    }
2035}