scouter_sql/sql/
postgres.rs

1use crate::sql::traits::{
2    AlertSqlLogic, ArchiveSqlLogic, CustomMetricSqlLogic, LLMDriftSqlLogic, ObservabilitySqlLogic,
3    ProfileSqlLogic, PsiSqlLogic, SpcSqlLogic, UserSqlLogic,
4};
5
6use crate::sql::error::SqlError;
7use scouter_settings::DatabaseSettings;
8use scouter_types::{RecordType, ServerRecords, ToDriftRecords};
9
10use sqlx::ConnectOptions;
11use sqlx::{postgres::PgConnectOptions, Pool, Postgres};
12use std::result::Result::Ok;
13use tracing::{debug, error, info, instrument};
14
15// TODO: Explore refactoring and breaking this out into multiple client types (i.e., spc, psi, etc.)
16// Postgres client is one of the lowest-level abstractions so it may not be worth it, as it could make server logic annoying. Worth exploring though.
17
18#[derive(Debug, Clone)]
19#[allow(dead_code)]
20pub struct PostgresClient {}
21
22impl SpcSqlLogic for PostgresClient {}
23impl CustomMetricSqlLogic for PostgresClient {}
24impl PsiSqlLogic for PostgresClient {}
25impl LLMDriftSqlLogic for PostgresClient {}
26impl UserSqlLogic for PostgresClient {}
27impl ProfileSqlLogic for PostgresClient {}
28impl ObservabilitySqlLogic for PostgresClient {}
29impl AlertSqlLogic for PostgresClient {}
30impl ArchiveSqlLogic for PostgresClient {}
31
32impl PostgresClient {
33    /// Setup the application with the given database pool.
34    ///
35    /// # Returns
36    ///
37    /// * `Result<Pool<Postgres>, anyhow::Error>` - Result of the database pool
38    #[instrument(skip(database_settings))]
39    pub async fn create_db_pool(
40        database_settings: &DatabaseSettings,
41    ) -> Result<Pool<Postgres>, SqlError> {
42        let mut opts: PgConnectOptions = database_settings.connection_uri.parse()?;
43
44        // Sqlx logs a lot of debug information by default, which can be overwhelming.
45
46        opts = opts.log_statements(tracing::log::LevelFilter::Off);
47
48        let pool = match sqlx::postgres::PgPoolOptions::new()
49            .max_connections(database_settings.max_connections)
50            .connect_with(opts)
51            .await
52        {
53            Ok(pool) => {
54                info!("✅ Successfully connected to database");
55                pool
56            }
57            Err(err) => {
58                error!("🚨 Failed to connect to database {:?}", err);
59                std::process::exit(1);
60            }
61        };
62
63        // Run migrations
64        if let Err(err) = Self::run_migrations(&pool).await {
65            error!("🚨 Failed to run migrations {:?}", err);
66            std::process::exit(1);
67        }
68
69        Ok(pool)
70    }
71
72    pub async fn run_migrations(pool: &Pool<Postgres>) -> Result<(), SqlError> {
73        info!("Running migrations");
74        sqlx::migrate!("src/migrations")
75            .run(pool)
76            .await
77            .map_err(SqlError::MigrateError)?;
78
79        debug!("Migrations complete");
80
81        Ok(())
82    }
83}
84
85pub struct MessageHandler {}
86
87impl MessageHandler {
88    const DEFAULT_BATCH_SIZE: usize = 500;
89    #[instrument(skip_all)]
90    pub async fn insert_server_records(
91        pool: &Pool<Postgres>,
92        records: &ServerRecords,
93    ) -> Result<(), SqlError> {
94        debug!("Inserting server records: {:?}", records.record_type()?);
95
96        match records.record_type()? {
97            RecordType::Spc => {
98                let spc_records = records.to_spc_drift_records()?;
99                debug!("SPC record count: {}", spc_records.len());
100
101                for chunk in spc_records.chunks(Self::DEFAULT_BATCH_SIZE) {
102                    PostgresClient::insert_spc_drift_records_batch(pool, chunk)
103                        .await
104                        .map_err(|e| {
105                            error!("Failed to insert SPC drift records batch: {:?}", e);
106                            e
107                        })?;
108                }
109            }
110
111            RecordType::Psi => {
112                let psi_records = records.to_psi_drift_records()?;
113                debug!("PSI record count: {}", psi_records.len());
114
115                for chunk in psi_records.chunks(Self::DEFAULT_BATCH_SIZE) {
116                    PostgresClient::insert_bin_counts_batch(pool, chunk)
117                        .await
118                        .map_err(|e| {
119                            error!("Failed to insert PSI drift records batch: {:?}", e);
120                            e
121                        })?;
122                }
123            }
124            RecordType::Custom => {
125                let custom_records = records.to_custom_metric_drift_records()?;
126                debug!("Custom record count: {}", custom_records.len());
127
128                for chunk in custom_records.chunks(Self::DEFAULT_BATCH_SIZE) {
129                    PostgresClient::insert_custom_metric_values_batch(pool, chunk)
130                        .await
131                        .map_err(|e| {
132                            error!("Failed to insert custom metric records batch: {:?}", e);
133                            e
134                        })?;
135                }
136            }
137
138            RecordType::LLMDrift => {
139                debug!("LLM Drift record count: {:?}", records.len());
140                let records = records.to_llm_drift_records()?;
141                for record in records.iter() {
142                    let _ = PostgresClient::insert_llm_drift_record(pool, record)
143                        .await
144                        .map_err(|e| {
145                            error!("Failed to insert LLM drift record: {:?}", e);
146                        });
147                }
148            }
149
150            RecordType::LLMMetric => {
151                debug!("LLM Metric record count: {:?}", records.len());
152                let llm_metric_records = records.to_llm_metric_records()?;
153
154                for chunk in llm_metric_records.chunks(Self::DEFAULT_BATCH_SIZE) {
155                    PostgresClient::insert_llm_metric_values_batch(pool, chunk)
156                        .await
157                        .map_err(|e| {
158                            error!("Failed to insert LLM metric records batch: {:?}", e);
159                            e
160                        })?;
161                }
162            }
163
164            _ => {
165                error!(
166                    "Unsupported record type for batch insert: {:?}",
167                    records.record_type()?
168                );
169                return Err(SqlError::UnsupportedBatchTypeError);
170            }
171        }
172
173        Ok(())
174    }
175}
176
/// Runs database integration tests.
/// Note - binned queries targeting custom intervals with long-term and short-term data are
/// done in the scouter-server integration tests.
#[cfg(test)]
mod tests {

    use super::*;
    use crate::sql::schema::User;
    use chrono::Utc;
    use potato_head::create_score_prompt;
    use rand::Rng;
    use scouter_semver::VersionType;
    use scouter_settings::ObjectStorageSettings;
    use scouter_types::llm::PaginationRequest;
    use scouter_types::psi::{Bin, BinType, PsiDriftConfig, PsiFeatureDriftProfile};
    use scouter_types::spc::SpcDriftProfile;
    use scouter_types::*;
    use serde_json::Value;
    use sqlx::postgres::PgQueryResult;
    use std::collections::BTreeMap;

    const SPACE: &str = "space";
    const NAME: &str = "name";
    const VERSION: &str = "1.0.0";

    /// Deletes every row from the tables listed below.
    // NOTE(review): `cargo test` runs tests in parallel by default, and every test calls this
    // via `db_pool()`. Tests sharing one database may therefore wipe each other's data unless
    // run with `--test-threads=1` — verify how CI invokes them.
    pub async fn cleanup(pool: &Pool<Postgres>) {
        sqlx::raw_sql(
            r#"
            DELETE
            FROM scouter.spc_drift;

            DELETE
            FROM scouter.observability_metric;

            DELETE
            FROM scouter.custom_drift;

            DELETE
            FROM scouter.drift_alert;

            DELETE
            FROM scouter.drift_profile;

            DELETE
            FROM scouter.psi_drift;

            DELETE
            FROM scouter.user;

            DELETE
            FROM scouter.llm_drift_record;

            DELETE
            FROM scouter.llm_drift;
            "#,
        )
        // DELETE statements return no rows, so execute rather than fetch.
        .execute(pool)
        .await
        .unwrap();
    }

    /// Creates a migrated pool from the default settings and clears the test tables.
    pub async fn db_pool() -> Pool<Postgres> {
        let pool = PostgresClient::create_db_pool(&DatabaseSettings::default())
            .await
            .unwrap();

        cleanup(&pool).await;

        pool
    }

    /// Computes the next profile version (minor bump) from the profile's base args and
    /// inserts `profile` under that version.
    pub async fn insert_profile_to_db(
        pool: &Pool<Postgres>,
        profile: &DriftProfile,
    ) -> PgQueryResult {
        let base_args = profile.get_base_args();
        let version = PostgresClient::get_next_profile_version(
            pool,
            &base_args,
            VersionType::Minor,
            None,
            None,
        )
        .await
        .unwrap();

        PostgresClient::insert_drift_profile(pool, profile, &base_args, &version)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_postgres() {
        let _pool = db_pool().await;
    }

    #[tokio::test]
    async fn test_postgres_drift_alert() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        // The same task and alert payload are inserted repeatedly, so build them once.
        let task_info = DriftTaskInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            uid: "test".to_string(),
            drift_type: DriftType::Spc,
        };

        let alert = (0..10)
            .map(|i| (i.to_string(), i.to_string()))
            .collect::<BTreeMap<String, String>>();

        for _ in 0..10 {
            let result = PostgresClient::insert_drift_alert(
                &pool,
                &task_info,
                "test",
                &alert,
                &DriftType::Spc,
            )
            .await
            .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        // get alerts
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);

        // get alerts limit 1
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: Some(1),
            limit_datetime: None,
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert_eq!(alerts.len(), 1);

        // get alerts limit timestamp
        let alert_request = DriftAlertRequest {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            active: Some(true),
            limit: None,
            limit_datetime: Some(timestamp),
        };

        let alerts = PostgresClient::get_drift_alerts(&pool, &alert_request)
            .await
            .unwrap();
        assert!(alerts.len() > 5);
    }

    #[tokio::test]
    async fn test_postgres_spc_drift_record() {
        let pool = db_pool().await;

        let record1 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            value: 1.0,
        };

        let record2 = SpcServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            value: 2.0,
        };

        let result = PostgresClient::insert_spc_drift_records_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    #[tokio::test]
    async fn test_postgres_bin_count() {
        let pool = db_pool().await;

        let record1 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test".to_string(),
            bin_id: 1,
            bin_count: 1,
        };

        let record2 = PsiServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            feature: "test2".to_string(),
            bin_id: 2,
            bin_count: 2,
        };

        let result = PostgresClient::insert_bin_counts_batch(&pool, &[record1, record2])
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 2);
    }

    #[tokio::test]
    async fn test_postgres_observability_record() {
        let pool = db_pool().await;

        let record = ObservabilityMetrics::default();

        let result = PostgresClient::insert_observability_record(&pool, &record)
            .await
            .unwrap();

        assert_eq!(result.rows_affected(), 1);
    }

    #[tokio::test]
    async fn test_postgres_crud_drift_profile() {
        let pool = db_pool().await;

        let mut spc_profile = SpcDriftProfile::default();
        let profile = DriftProfile::Spc(spc_profile.clone());

        let result = insert_profile_to_db(&pool, &profile).await;
        assert_eq!(result.rows_affected(), 1);

        spc_profile.scouter_version = "test".to_string();

        let result =
            PostgresClient::update_drift_profile(&pool, &DriftProfile::Spc(spc_profile.clone()))
                .await
                .unwrap();

        assert_eq!(result.rows_affected(), 1);

        let profile = PostgresClient::get_drift_profile(
            &pool,
            &GetProfileRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                drift_type: DriftType::Spc,
            },
        )
        .await
        .unwrap();

        let deserialized = serde_json::from_value::<SpcDriftProfile>(profile.unwrap()).unwrap();

        assert_eq!(deserialized, spc_profile);

        PostgresClient::update_drift_profile_status(
            &pool,
            &ProfileStatusRequest {
                name: spc_profile.config.name.clone(),
                space: spc_profile.config.space.clone(),
                version: spc_profile.config.version.clone(),
                active: false,
                drift_type: Some(DriftType::Spc),
                deactivate_others: false,
            },
        )
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_postgres_get_features() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for _ in 0..10 {
            let mut records = Vec::new();
            for j in 0..10 {
                let record = SpcServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    feature: format!("test{j}"),
                    value: j as f64,
                };

                records.push(record);
            }

            let result = PostgresClient::insert_spc_drift_records_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), records.len() as u64);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_spc_features(&pool, &service_info)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        let records =
            PostgresClient::get_spc_drift_records(&pool, &service_info, &timestamp, &features)
                .await
                .unwrap();

        assert_eq!(records.features.len(), 10);

        let binned_records = PostgresClient::get_binned_spc_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::FiveMinutes,
                max_data_points: 10,
                drift_type: DriftType::Spc,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.features.len(), 10);
    }

    #[tokio::test]
    async fn test_postgres_bin_proportions() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        let num_features = 3;
        let num_bins = 5;

        // NOTE(review): the profile defines features 0..=num_features (4 features) while
        // records below are written only for 0..num_features (3 features); the final
        // expected count of 3 matches the features that actually have data.
        let features = (0..=num_features)
            .map(|feature| {
                let bins = (0..=num_bins)
                    .map(|bin_id| Bin {
                        id: bin_id,
                        lower_limit: None,
                        upper_limit: None,
                        proportion: 0.0,
                    })
                    .collect();
                let feature_name = format!("feature{feature}");
                let feature_profile = PsiFeatureDriftProfile {
                    id: feature_name.clone(),
                    bins,
                    timestamp,
                    bin_type: BinType::Numeric,
                };
                (feature_name, feature_profile)
            })
            .collect();

        let profile = &DriftProfile::Psi(psi::PsiDriftProfile::new(
            features,
            PsiDriftConfig {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                ..Default::default()
            },
        ));
        let _ = insert_profile_to_db(&pool, profile).await;

        for feature in 0..num_features {
            for bin in 0..=num_bins {
                let mut records = Vec::new();
                for j in 0..=100 {
                    let record = PsiServerRecord {
                        created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                        space: SPACE.to_string(),
                        name: NAME.to_string(),
                        version: VERSION.to_string(),
                        feature: format!("feature{feature}"),
                        bin_id: bin,
                        bin_count: rand::rng().random_range(0..10),
                    };

                    records.push(record);
                }
                PostgresClient::insert_bin_counts_batch(&pool, &records)
                    .await
                    .unwrap();
            }
        }

        let binned_records = PostgresClient::get_feature_distributions(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["feature0".to_string()],
        )
        .await
        .unwrap();

        // Every bin 0..=num_bins (6 bins) gets the same number of uniformly random counts,
        // so bin 1 should hold roughly 1/6 of feature0's total.
        let bin_proportion = binned_records
            .distributions
            .get("feature0")
            .unwrap()
            .bins
            .get(&1)
            .unwrap();

        assert!(*bin_proportion > 0.1 && *bin_proportion < 0.2);

        let binned_records = PostgresClient::get_binned_psi_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Psi,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        assert_eq!(binned_records.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_cru_custom_metric() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            let mut records = Vec::new();
            for j in 0..25 {
                let record = CustomMetricServerRecord {
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };
                records.push(record);
            }
            let result = PostgresClient::insert_custom_metric_values_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), 25);
        }

        // insert a single record so the statistics functions are exercised on one data point
        let record = CustomMetricServerRecord {
            created_at: Utc::now(),
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
            metric: "metric3".to_string(),
            value: rand::rng().random_range(0..10) as f64,
        };

        let result = PostgresClient::insert_custom_metric_values_batch(&pool, &[record])
            .await
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let metrics = PostgresClient::get_custom_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);

        let binned_records = PostgresClient::get_binned_custom_drift_records(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::Custom,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        // metric0, metric1 and metric3
        assert_eq!(binned_records.metrics.len(), 3);
    }

    #[tokio::test]
    async fn test_postgres_user() {
        let pool = db_pool().await;
        let recovery_codes = vec!["recovery_code_1".to_string(), "recovery_code_2".to_string()];

        // Create
        let user = User::new(
            "user".to_string(),
            "pass".to_string(),
            "email".to_string(),
            recovery_codes,
            None,
            None,
            None,
            None,
        );
        PostgresClient::insert_user(&pool, &user).await.unwrap();

        // Read
        let mut user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();

        assert_eq!(user.username, "user");
        assert_eq!(user.group_permissions, vec!["user"]);
        assert_eq!(user.email, "email");

        // update user
        user.active = false;
        user.refresh_token = Some("token".to_string());

        // Update
        PostgresClient::update_user(&pool, &user).await.unwrap();
        let user = PostgresClient::get_user(&pool, "user")
            .await
            .unwrap()
            .unwrap();
        assert!(!user.active);
        assert_eq!(user.refresh_token.unwrap(), "token");

        // get users
        let users = PostgresClient::get_users(&pool).await.unwrap();
        assert_eq!(users.len(), 1);

        // get last admin
        let is_last_admin = PostgresClient::is_last_admin(&pool, "user").await.unwrap();
        assert!(!is_last_admin);

        // delete
        PostgresClient::delete_user(&pool, "user").await.unwrap();
    }

    #[tokio::test]
    async fn test_postgres_llm_drift_record_insert_get() {
        let pool = db_pool().await;

        let input = "This is a test input";
        let output = "This is a test response";
        let prompt = create_score_prompt(None);

        for j in 0..10 {
            let context = serde_json::json!({
                "input": input,
                "response": output,
            });
            let record = LLMDriftServerRecord {
                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                prompt: Some(prompt.model_dump_value()),
                context,
                status: Status::Pending,
                id: 0, // This will be set by the database
                uid: "test".to_string(),
                updated_at: None,
                score: Value::Null,
                processing_started_at: None,
                processing_ended_at: None,
                processing_duration: None,
            };

            let result = PostgresClient::insert_llm_drift_record(&pool, &record)
                .await
                .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        let features = PostgresClient::get_llm_drift_records(&pool, &service_info, None, None)
            .await
            .unwrap();
        assert_eq!(features.len(), 10);

        // get pending task
        let pending_task = PostgresClient::get_pending_llm_drift_record(&pool)
            .await
            .unwrap()
            .expect("expected a pending LLM drift record");

        // the pending task carries the context it was inserted with
        let task_input = &pending_task.context["input"];
        assert_eq!(*task_input, input);

        // update pending task
        PostgresClient::update_llm_drift_record_status(
            &pool,
            &pending_task,
            Status::Processed,
            Some(1),
        )
        .await
        .unwrap();

        // query processed tasks
        let processed_tasks = PostgresClient::get_llm_drift_records(
            &pool,
            &service_info,
            None,
            Some(Status::Processed),
        )
        .await
        .unwrap();

        assert_eq!(processed_tasks.len(), 1);
    }

    #[tokio::test]
    async fn test_postgres_llm_drift_record_pagination() {
        let pool = db_pool().await;

        let input = "This is a test input";
        let output = "This is a test response";
        let prompt = create_score_prompt(None);

        for j in 0..10 {
            let context = serde_json::json!({
                "input": input,
                "response": output,
            });
            let record = LLMDriftServerRecord {
                created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                prompt: Some(prompt.model_dump_value()),
                context,
                score: Value::Null,
                status: Status::Pending,
                id: 0, // This will be set by the database
                uid: "test".to_string(),
                updated_at: None,
                processing_started_at: None,
                processing_ended_at: None,
                processing_duration: None,
            };

            let result = PostgresClient::insert_llm_drift_record(&pool, &record)
                .await
                .unwrap();

            assert_eq!(result.rows_affected(), 1);
        }

        let service_info = ServiceInfo {
            space: SPACE.to_string(),
            name: NAME.to_string(),
            version: VERSION.to_string(),
        };

        // Get paginated records (1st page)
        let pagination = PaginationRequest {
            limit: 5,
            cursor: None, // Start from the beginning
        };

        let paginated_features = PostgresClient::get_llm_drift_records_pagination(
            &pool,
            &service_info,
            None,
            pagination,
        )
        .await
        .unwrap();

        assert_eq!(paginated_features.items.len(), 5);
        assert!(paginated_features.next_cursor.is_some());

        // the first item of the first page is the most recent record
        let newest_record = paginated_features.items.first().unwrap();

        // Get paginated records (2nd page)
        let next_cursor = paginated_features.next_cursor.clone().unwrap();
        let pagination = PaginationRequest {
            limit: 5,
            cursor: Some(next_cursor),
        };

        let second_page = PostgresClient::get_llm_drift_records_pagination(
            &pool,
            &service_info,
            None,
            pagination,
        )
        .await
        .unwrap();

        assert_eq!(second_page.items.len(), 5);
        assert!(second_page.next_cursor.is_none());

        // the last item of the second page is the oldest record
        let oldest_record = second_page.items.last().unwrap();

        // +1 because both endpoints are counted
        let span = newest_record.id - oldest_record.id + 1;
        assert_eq!(span, 10);
    }

    #[tokio::test]
    async fn test_postgres_llm_metrics_insert_get() {
        let pool = db_pool().await;

        let timestamp = Utc::now();

        for i in 0..2 {
            let mut records = Vec::new();
            for j in 0..25 {
                let record = LLMMetricRecord {
                    record_uid: format!("uid{i}{j}"),
                    created_at: Utc::now() + chrono::Duration::microseconds(j as i64),
                    space: SPACE.to_string(),
                    name: NAME.to_string(),
                    version: VERSION.to_string(),
                    metric: format!("metric{i}"),
                    value: rand::rng().random_range(0..10) as f64,
                };
                records.push(record);
            }
            let result = PostgresClient::insert_llm_metric_values_batch(&pool, &records)
                .await
                .unwrap();
            assert_eq!(result.rows_affected(), 25);
        }

        let metrics = PostgresClient::get_llm_metric_values(
            &pool,
            &ServiceInfo {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
            },
            &timestamp,
            &["metric1".to_string()],
        )
        .await
        .unwrap();

        assert_eq!(metrics.len(), 1);
        let binned_records = PostgresClient::get_binned_llm_metric_values(
            &pool,
            &DriftRequest {
                space: SPACE.to_string(),
                name: NAME.to_string(),
                version: VERSION.to_string(),
                time_interval: TimeInterval::OneHour,
                max_data_points: 1000,
                drift_type: DriftType::LLM,
                ..Default::default()
            },
            &DatabaseSettings::default().retention_period,
            &ObjectStorageSettings::default(),
        )
        .await
        .unwrap();

        // metric0 and metric1
        assert_eq!(binned_records.metrics.len(), 2);
    }
}