use crate::sql::error::SqlError;
use crate::sql::query::Queries;
use crate::sql::schema::BinnedMetricWrapper;
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use scouter_dataframe::parquet::{BinnedMetricsExtractor, ParquetDataFrame};
use scouter_settings::ObjectStorageSettings;
use scouter_types::contracts::DriftRequest;
use scouter_types::{
    BinnedMetrics, BoxedEvalRecord, EvalRecord, EvalRecordPaginationRequest,
    EvalRecordPaginationResponse, EvalTaskResult, GenAIEvalWorkflowPaginationResponse,
    GenAIEvalWorkflowResult, RecordCursor, RecordType, Status,
};
use sqlx::types::Json;
use sqlx::{postgres::PgQueryResult, Pool, Postgres, Row};
use std::collections::HashMap;
use tracing::{debug, error, instrument};

#[async_trait]
pub trait GenAIDriftSqlLogic {
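    /// Insert a single GenAI evaluation record for the given entity.
    ///
    /// The record context is stored as JSON and the trace id, when present,
    /// is persisted as raw bytes.
    ///
    /// A minimal usage sketch (`Db` stands in for a type implementing this
    /// trait):
    /// ```ignore
    /// Db::insert_genai_eval_record(&pool, boxed_record, &entity_id).await?;
    /// ```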
    async fn insert_genai_eval_record(
        pool: &Pool<Postgres>,
        record: BoxedEvalRecord,
        entity_id: &i32,
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertEvalRecord.get_query();

        sqlx::query(query)
            .bind(record.record.uid)
            .bind(record.record.created_at)
            .bind(entity_id)
            .bind(Json(record.record.context))
            .bind(&record.record.record_id)
            .bind(&record.record.session_id)
            .bind(record.record.trace_id.map(|t| t.as_bytes().to_vec()))
            .bind(&record.record.tags)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

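    /// Insert the aggregate result of a GenAI evaluation workflow (task
    /// counts, pass rate, duration, and execution plan) for the given entity.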
    async fn insert_genai_eval_workflow_record(
        pool: &Pool<Postgres>,
        record: &GenAIEvalWorkflowResult,
        entity_id: &i32,
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertGenAIWorkflowResult.get_query();

        sqlx::query(query)
            .bind(record.created_at)
            .bind(record.record_uid.as_str())
            .bind(entity_id)
            .bind(record.total_tasks)
            .bind(record.passed_tasks)
            .bind(record.failed_tasks)
            .bind(record.pass_rate)
            .bind(record.duration_ms)
            .bind(Json(&record.execution_plan))
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

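    /// Insert a batch of task results in a single statement.
    ///
    /// Each column is collected into its own vector and bound as a Postgres
    /// array, so the whole batch is written in one round trip; an empty batch
    /// is rejected up front.
    ///
    /// A sketch of how such a columnar batch is typically bound (the UNNEST
    /// shape of the underlying query is an assumption here, not a quote of
    /// the production SQL):
    /// ```ignore
    /// // INSERT INTO task_results SELECT * FROM UNNEST($1::timestamptz[], $2::text[], ...)
    /// sqlx::query(query)
    ///     .bind(&created_ats) // one Postgres array per column
    ///     .bind(&record_uids)
    ///     .execute(pool)
    ///     .await?;
    /// ```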
    async fn insert_eval_task_results_batch(
        pool: &Pool<Postgres>,
        records: &[EvalTaskResult],
        entity_id: &i32,
    ) -> Result<PgQueryResult, SqlError> {
        if records.is_empty() {
            return Err(SqlError::EmptyBatchError);
        }

        let n = records.len();

        let mut created_ats = Vec::with_capacity(n);
        let mut start_times = Vec::with_capacity(n);
        let mut end_times = Vec::with_capacity(n);
        let mut record_uids = Vec::with_capacity(n);
        let mut entity_ids = Vec::with_capacity(n);
        let mut task_ids = Vec::with_capacity(n);
        let mut task_types = Vec::with_capacity(n);
        let mut passed_flags = Vec::with_capacity(n);
        let mut values = Vec::with_capacity(n);
        let mut assertions = Vec::with_capacity(n);
        let mut operators = Vec::with_capacity(n);
        let mut expected_jsons = Vec::with_capacity(n);
        let mut actual_jsons = Vec::with_capacity(n);
        let mut messages = Vec::with_capacity(n);
        let mut conditions = Vec::with_capacity(n);
        let mut stages = Vec::with_capacity(n);

        for r in records {
            created_ats.push(r.created_at);
            start_times.push(r.start_time);
            end_times.push(r.end_time);
            record_uids.push(&r.record_uid);
            entity_ids.push(entity_id);
            task_ids.push(&r.task_id);
            task_types.push(r.task_type.as_str());
            passed_flags.push(r.passed);
            values.push(r.value);
            assertions.push(Json(r.assertion()));
            operators.push(r.operator.as_str());
            expected_jsons.push(Json(&r.expected));
            actual_jsons.push(Json(&r.actual));
            messages.push(&r.message);
            conditions.push(r.condition);
            stages.push(r.stage);
        }

        let query = Queries::InsertGenAITaskResultsBatch.get_query();

        sqlx::query(query)
            .bind(&created_ats)
            .bind(&start_times)
            .bind(&end_times)
            .bind(&record_uids)
            .bind(&entity_ids)
            .bind(&task_ids)
            .bind(&task_types)
            .bind(&passed_flags)
            .bind(&values)
            .bind(&assertions)
            .bind(&operators)
            .bind(&expected_jsons)
            .bind(&actual_jsons)
            .bind(&messages)
            .bind(&conditions)
            .bind(&stages)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

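    /// Fetch evaluation records for an entity, optionally filtered by a lower
    /// `created_at` bound and/or a status.
    ///
    /// Filter clauses are appended to the base query with sequentially
    /// numbered bind parameters, and sensitive fields are masked before the
    /// records are returned.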
    async fn get_genai_eval_records(
        pool: &Pool<Postgres>,
        limit_datetime: Option<&DateTime<Utc>>,
        status: Option<Status>,
        entity_id: &i32,
    ) -> Result<Vec<EvalRecord>, SqlError> {
        let mut query_string = Queries::GetEvalRecords.get_query().to_string();
        let mut bind_count = 1;

        if limit_datetime.is_some() {
            bind_count += 1;
            query_string.push_str(&format!(" AND created_at > ${bind_count}"));
        }

        let status_value = status.as_ref().and_then(|s| s.as_str());
        if status_value.is_some() {
            bind_count += 1;
            query_string.push_str(&format!(" AND status = ${bind_count}"));
        }

        let mut query = sqlx::query_as::<_, EvalRecord>(&query_string).bind(entity_id);

        if let Some(datetime) = limit_datetime {
            query = query.bind(datetime);
        }
        if let Some(status) = status_value {
            query = query.bind(status);
        }

        let records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;

        Ok(records
            .into_iter()
            .map(|mut r| {
                r.mask_sensitive_data();
                r
            })
            .collect())
    }

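    /// Fetch a page of evaluation records using keyset (cursor) pagination on
    /// `(created_at, id)`.
    ///
    /// Paging forward, sketched with a hypothetical request value (`Db` is a
    /// type implementing this trait):
    /// ```ignore
    /// let mut req = EvalRecordPaginationRequest::default();
    /// req.limit = Some(25);
    /// let page = Db::get_paginated_genai_eval_records(&pool, &req, &entity_id).await?;
    /// if let Some(cursor) = page.next_cursor {
    ///     req.cursor_created_at = Some(cursor.created_at);
    ///     req.cursor_id = Some(cursor.id);
    /// }
    /// ```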
    #[instrument(skip_all)]
    async fn get_paginated_genai_eval_records(
        pool: &Pool<Postgres>,
        params: &EvalRecordPaginationRequest,
        entity_id: &i32,
    ) -> Result<EvalRecordPaginationResponse, SqlError> {
        let query = Queries::GetPaginatedEvalRecords.get_query();
        let limit = params.limit.unwrap_or(50);
        let direction = params.direction.as_deref().unwrap_or("next");

        let mut items: Vec<EvalRecord> = sqlx::query_as(query)
            .bind(entity_id)
            .bind(params.status.as_ref().and_then(|s| s.as_str()))
            .bind(params.cursor_created_at)
            .bind(direction)
            .bind(params.cursor_id)
            .bind(limit)
            .bind(params.start_datetime)
            .bind(params.end_datetime)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        // The query fetches up to limit + 1 rows; an extra row signals that
        // another page exists and is dropped from the response.
        let has_more = items.len() > limit as usize;

        if has_more {
            items.pop();
        }

        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
            "previous" => {
                // Rows come back in reverse order when paging backwards.
                items.reverse();

                let previous_cursor = if has_more {
                    items.first().map(|first| RecordCursor {
                        created_at: first.created_at,
                        id: first.id,
                    })
                } else {
                    None
                };

                let next_cursor = items.last().map(|last| RecordCursor {
                    created_at: last.created_at,
                    id: last.id,
                });

                (
                    params.cursor_created_at.is_some(),
                    next_cursor,
                    has_more,
                    previous_cursor,
                )
            }
            _ => {
                let next_cursor = if has_more {
                    items.last().map(|last| RecordCursor {
                        created_at: last.created_at,
                        id: last.id,
                    })
                } else {
                    None
                };

                let previous_cursor = items.first().map(|first| RecordCursor {
                    created_at: first.created_at,
                    id: first.id,
                });

                (
                    has_more,
                    next_cursor,
                    params.cursor_created_at.is_some(),
                    previous_cursor,
                )
            }
        };

        let public_items = items
            .into_iter()
            .map(|mut r| {
                r.mask_sensitive_data();
                r
            })
            .collect();

        Ok(EvalRecordPaginationResponse {
            items: public_items,
            has_next,
            next_cursor,
            has_previous,
            previous_cursor,
        })
    }

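    /// Fetch all task results belonging to a single evaluation record,
    /// identified by its `record_uid`.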
    async fn get_genai_eval_task(
        pool: &Pool<Postgres>,
        record_uid: &str,
    ) -> Result<Vec<EvalTaskResult>, SqlError> {
        let query = Queries::GetGenAIEvalTasks.get_query();

        sqlx::query_as(query)
            .bind(record_uid)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)
    }

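    /// Fetch a page of workflow results using the same keyset pagination
    /// scheme as
    /// [`get_paginated_genai_eval_records`](Self::get_paginated_genai_eval_records),
    /// minus the status filter.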
    #[instrument(skip_all)]
    async fn get_paginated_genai_eval_workflow_records(
        pool: &Pool<Postgres>,
        params: &EvalRecordPaginationRequest,
        entity_id: &i32,
    ) -> Result<GenAIEvalWorkflowPaginationResponse, SqlError> {
        let query = Queries::GetPaginatedGenAIEvalWorkflow.get_query();
        let limit = params.limit.unwrap_or(50);
        let direction = params.direction.as_deref().unwrap_or("next");

        let mut items: Vec<GenAIEvalWorkflowResult> = sqlx::query_as(query)
            .bind(entity_id)
            .bind(params.cursor_created_at)
            .bind(direction)
            .bind(params.cursor_id)
            .bind(limit)
            .bind(params.start_datetime)
            .bind(params.end_datetime)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        let has_more = items.len() > limit as usize;

        if has_more {
            items.pop();
        }

        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
            "previous" => {
                items.reverse();

                let previous_cursor = if has_more {
                    items.first().map(|first| RecordCursor {
                        created_at: first.created_at,
                        id: first.id,
                    })
                } else {
                    None
                };

                let next_cursor = items.last().map(|last| RecordCursor {
                    created_at: last.created_at,
                    id: last.id,
                });

                (
                    params.cursor_created_at.is_some(),
                    next_cursor,
                    has_more,
                    previous_cursor,
                )
            }
            _ => {
                let next_cursor = if has_more {
                    items.last().map(|last| RecordCursor {
                        created_at: last.created_at,
                        id: last.id,
                    })
                } else {
                    None
                };

                let previous_cursor = items.first().map(|first| RecordCursor {
                    created_at: first.created_at,
                    id: first.id,
                });

                (
                    has_more,
                    next_cursor,
                    params.cursor_created_at.is_some(),
                    previous_cursor,
                )
            }
        };

        let public_items = items
            .into_iter()
            .map(|mut r| {
                r.mask_sensitive_data();
                r
            })
            .collect();

        Ok(GenAIEvalWorkflowPaginationResponse {
            items: public_items,
            has_next,
            next_cursor,
            has_previous,
            previous_cursor,
        })
    }

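    /// Fetch one value per requested task metric recorded after
    /// `limit_datetime`, keyed by metric name. The exact aggregation is
    /// defined by the `GetGenAITaskValues` query.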
    #[instrument(skip_all)]
    async fn get_genai_task_values(
        pool: &Pool<Postgres>,
        limit_datetime: &DateTime<Utc>,
        metrics: &[String],
        entity_id: &i32,
    ) -> Result<HashMap<String, f64>, SqlError> {
        let query = Queries::GetGenAITaskValues.get_query();

        let records = sqlx::query(query)
            .bind(limit_datetime)
            .bind(entity_id)
            .bind(metrics)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        let metric_map = records
            .into_iter()
            .map(|row| {
                let metric = row.get("metric");
                let value = row.get("value");
                (metric, value)
            })
            .collect();

        Ok(metric_map)
    }

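    /// Fetch the workflow-level value for an entity after `limit_datetime`,
    /// returning `None` when no row (or no value) is available. The
    /// aggregation itself lives in the `GetGenAIWorkflowValues` query.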
    #[instrument(skip_all)]
    async fn get_genai_workflow_value(
        pool: &Pool<Postgres>,
        limit_datetime: &DateTime<Utc>,
        entity_id: &i32,
    ) -> Result<Option<f64>, SqlError> {
        let query = Queries::GetGenAIWorkflowValues.get_query();

        let record = sqlx::query(query)
            .bind(limit_datetime)
            .bind(entity_id)
            .fetch_optional(pool)
            .await
            .inspect_err(|e| {
                error!("Error fetching GenAI workflow values: {:?}", e);
            })?;

        Ok(record.and_then(|r| r.try_get("value").ok()))
    }

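    /// Fetch workflow metrics binned over `[start_dt, end_dt]`.
    ///
    /// The bin width (in minutes) is the span divided by
    /// `params.max_data_points`, so the response never exceeds the requested
    /// number of points.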
    #[instrument(skip_all)]
    async fn get_binned_workflow_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        start_dt: DateTime<Utc>,
        end_dt: DateTime<Utc>,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
        let bin = minutes / params.max_data_points as f64;

        let query = Queries::GetGenAIWorkflowBinnedMetrics.get_query();

        let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
            .bind(bin)
            .bind(start_dt)
            .bind(end_dt)
            .bind(entity_id)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(BinnedMetrics::from_vec(
            records.into_iter().map(|wrapper| wrapper.0).collect(),
        ))
    }

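    /// Fetch task metrics binned over `[start_dt, end_dt]`, using the same
    /// bin-width calculation as
    /// [`get_binned_workflow_records`](Self::get_binned_workflow_records).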
    #[instrument(skip_all)]
    async fn get_binned_task_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        start_dt: DateTime<Utc>,
        end_dt: DateTime<Utc>,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
        let bin = minutes / params.max_data_points as f64;

        let query = Queries::GetGenAITaskBinnedMetrics.get_query();

        let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
            .bind(bin)
            .bind(start_dt)
            .bind(end_dt)
            .bind(entity_id)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(BinnedMetrics::from_vec(
            records.into_iter().map(|wrapper| wrapper.0).collect(),
        ))
    }

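    /// Merge one set of binned metrics into an accumulator map: timestamps
    /// and stats are appended for metrics that already exist, otherwise the
    /// metric is inserted as-is. The clone exists because `and_modify` and
    /// `or_insert` both need ownership of the metric.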
    fn merge_feature_results(
        results: BinnedMetrics,
        map: &mut BinnedMetrics,
    ) -> Result<(), SqlError> {
        for (name, metric) in results.metrics {
            let metric_clone = metric.clone();
            map.metrics
                .entry(name)
                .and_modify(|existing| {
                    existing.created_at.extend(metric_clone.created_at);
                    existing.stats.extend(metric_clone.stats);
                })
                .or_insert(metric);
        }

        Ok(())
    }

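    /// Fetch binned task metrics from archived parquet storage for data that
    /// has aged out of the database retention window.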
    #[instrument(skip_all)]
    async fn get_archived_task_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting archived GenAI task metrics for params: {:?}", params);
        let path = format!("{}/{}", params.uid, RecordType::GenAITask);
        let bin = minutes as f64 / params.max_data_points as f64;
        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAITask)?
            .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
            .await
            .inspect_err(|e| {
                error!("Failed to get archived GenAI task metrics: {:?}", e);
            })?;

        Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
    }

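    /// Workflow counterpart of
    /// [`get_archived_task_records`](Self::get_archived_task_records): reads
    /// binned workflow metrics from archived parquet storage.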
    #[instrument(skip_all)]
    async fn get_archived_workflow_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting archived GenAI workflow metrics for params: {:?}", params);
        let path = format!("{}/{}", params.uid, RecordType::GenAIWorkflow);
        let bin = minutes as f64 / params.max_data_points as f64;
        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAIWorkflow)?
            .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
            .await
            .inspect_err(|e| {
                error!("Failed to get archived GenAI workflow metrics: {:?}", e);
            })?;

        Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
    }

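    /// Resolve binned task metrics for either a default or a custom interval.
    ///
    /// Default intervals are served entirely from the database. A custom
    /// interval is split against the retention period: the active range is
    /// read from Postgres, the archived range from object storage, and the
    /// two results are merged.
    ///
    /// A sketch with assumed caller-side values (`Db` is a hypothetical
    /// implementor of this trait):
    /// ```ignore
    /// let metrics = Db::get_binned_genai_task_values(
    ///     &pool,
    ///     &drift_request,
    ///     &retention_days,
    ///     &storage_settings,
    ///     &entity_id,
    /// )
    /// .await?;
    /// ```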
    #[instrument(skip_all)]
    async fn get_binned_genai_task_values(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting binned task drift records for {:?}", params);

        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
            return Self::get_binned_task_records(pool, params, start_dt, end_dt, entity_id).await;
        }

        debug!("Custom interval provided, using custom interval");
        // Safe to unwrap: guarded by has_custom_interval() above.
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
        let mut custom_metric_map = BinnedMetrics::default();

        // Query data that is still within the retention window.
        if let Some((active_begin, active_end)) = timestamps.active_range {
            let current_results =
                Self::get_binned_task_records(pool, params, active_begin, active_end, entity_id)
                    .await?;
            Self::merge_feature_results(current_results, &mut custom_metric_map)?;
        }

        // Pull anything older than the retention window from archived storage.
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_task_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                    entity_id,
                )
                .await?;
                Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
            }
        }

        Ok(custom_metric_map)
    }

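    /// Workflow counterpart of
    /// [`get_binned_genai_task_values`](Self::get_binned_genai_task_values):
    /// splits custom intervals across active and archived storage and merges
    /// the results.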
    #[instrument(skip_all)]
    async fn get_binned_genai_workflow_values(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
        entity_id: &i32,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting binned workflow drift records for {:?}", params);

        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
            return Self::get_binned_workflow_records(pool, params, start_dt, end_dt, entity_id)
                .await;
        }

        debug!("Custom interval provided, using custom interval");
        // Safe to unwrap: guarded by has_custom_interval() above.
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
        let mut custom_metric_map = BinnedMetrics::default();

        // Query data that is still within the retention window.
        if let Some((active_begin, active_end)) = timestamps.active_range {
            let current_results = Self::get_binned_workflow_records(
                pool,
                params,
                active_begin,
                active_end,
                entity_id,
            )
            .await?;
            Self::merge_feature_results(current_results, &mut custom_metric_map)?;
        }

        // Pull anything older than the retention window from archived storage.
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_workflow_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                    entity_id,
                )
                .await?;
                Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
            }
        }

        debug!(
            "Merged {} binned workflow metrics",
            custom_metric_map.metrics.len()
        );

        Ok(custom_metric_map)
    }

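    /// Fetch the next pending evaluation record, if any, for background
    /// processing. The selection logic (ordering, locking) lives in the
    /// `GetPendingGenAIEvalTask` query.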
    async fn get_pending_genai_eval_record(
        pool: &Pool<Postgres>,
    ) -> Result<Option<EvalRecord>, SqlError> {
        let query = Queries::GetPendingGenAIEvalTask.get_query();
        let result: Option<EvalRecord> = sqlx::query_as(query)
            .fetch_optional(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        debug!("Fetched pending GenAI eval record: {:?}", result);

        Ok(result)
    }

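    /// Update the status and workflow duration of an evaluation record,
    /// identified by its uid.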
    #[instrument(skip_all)]
    async fn update_genai_eval_record_status(
        pool: &Pool<Postgres>,
        record: &EvalRecord,
        status: Status,
        workflow_duration: &i64,
    ) -> Result<(), SqlError> {
        let query = Queries::UpdateGenAIEvalTask.get_query();
        let _query_result = sqlx::query(query)
            .bind(status.as_str())
            .bind(workflow_duration)
            .bind(&record.uid)
            .execute(pool)
            .await
            .inspect_err(|e| {
                error!("Failed to update GenAI eval record status: {:?}", e);
            })?;

        Ok(())
    }

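    /// Push an evaluation record's scheduled time `delay` into the future,
    /// e.g. to retry after a transient failure.
    ///
    /// Sketch (the five-minute backoff is illustrative; `Db` is a
    /// hypothetical implementor):
    /// ```ignore
    /// Db::reschedule_genai_eval_record(&pool, &record.uid, Duration::minutes(5)).await?;
    /// ```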
    #[instrument(skip_all)]
    async fn reschedule_genai_eval_record(
        pool: &Pool<Postgres>,
        uid: &str,
        delay: Duration,
    ) -> Result<(), SqlError> {
        let scheduled_at = Utc::now() + delay;

        let query = Queries::RescheduleEvalRecord.get_query();
        sqlx::query(query)
            .bind(scheduled_at)
            .bind(uid)
            .execute(pool)
            .await
            .inspect_err(|e| {
                error!("Failed to reschedule GenAI eval record: {:?}", e);
            })?;

        Ok(())
    }
}