1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::schema::BinnedMetricWrapper;
4use crate::sql::utils::split_custom_interval;
5use async_trait::async_trait;
6use chrono::{DateTime, Duration, Utc};
7use scouter_dataframe::parquet::BinnedMetricsExtractor;
8use scouter_dataframe::parquet::ParquetDataFrame;
9use scouter_settings::ObjectStorageSettings;
10use scouter_types::contracts::DriftRequest;
11use scouter_types::BoxedGenAIEvalRecord;
12use scouter_types::GenAIEvalRecord;
13use scouter_types::GenAIEvalTaskResult;
14use scouter_types::GenAIEvalWorkflowPaginationResponse;
15use scouter_types::GenAIEvalWorkflowResult;
16use scouter_types::Status;
17use scouter_types::{
18 BinnedMetrics, GenAIEvalRecordPaginationRequest, GenAIEvalRecordPaginationResponse,
19 RecordCursor, RecordType,
20};
21use sqlx::types::Json;
22use sqlx::{postgres::PgQueryResult, Pool, Postgres, Row};
23use std::collections::HashMap;
24use tracing::error;
25use tracing::{debug, instrument};
26
27#[async_trait]
28pub trait GenAIDriftSqlLogic {
29 async fn insert_genai_eval_record(
36 pool: &Pool<Postgres>,
37 record: BoxedGenAIEvalRecord,
38 entity_id: &i32,
39 ) -> Result<PgQueryResult, SqlError> {
40 let query = Queries::InsertGenAIEvalRecord.get_query();
41
42 sqlx::query(query)
43 .bind(record.record.uid)
44 .bind(record.record.created_at)
45 .bind(entity_id)
46 .bind(Json(record.record.context))
47 .bind(&record.record.record_id)
48 .bind(&record.record.session_id)
49 .bind(record.record.trace_id.map(|t| t.as_bytes().to_vec()))
50 .execute(pool)
51 .await
52 .map_err(SqlError::SqlxError)
53 }
54
55 async fn insert_genai_eval_workflow_record(
62 pool: &Pool<Postgres>,
63 record: &GenAIEvalWorkflowResult,
64 entity_id: &i32,
65 ) -> Result<PgQueryResult, SqlError> {
66 let query = Queries::InsertGenAIWorkflowResult.get_query();
67
68 sqlx::query(query)
69 .bind(record.created_at)
70 .bind(record.record_uid.as_str())
71 .bind(entity_id)
72 .bind(record.total_tasks)
73 .bind(record.passed_tasks)
74 .bind(record.failed_tasks)
75 .bind(record.pass_rate)
76 .bind(record.duration_ms)
77 .bind(Json(&record.execution_plan))
78 .execute(pool)
79 .await
80 .map_err(SqlError::SqlxError)
81 }
82
83 async fn insert_eval_task_results_batch(
86 pool: &Pool<Postgres>,
87 records: &[GenAIEvalTaskResult], entity_id: &i32,
89 ) -> Result<sqlx::postgres::PgQueryResult, SqlError> {
90 if records.is_empty() {
91 return Err(SqlError::EmptyBatchError);
92 }
93
94 let n = records.len();
95
96 let mut created_ats = Vec::with_capacity(n);
98 let mut start_times = Vec::with_capacity(n);
99 let mut end_times = Vec::with_capacity(n);
100 let mut record_uids = Vec::with_capacity(n);
101 let mut entity_ids = Vec::with_capacity(n);
102 let mut task_ids = Vec::with_capacity(n);
103 let mut task_types = Vec::with_capacity(n);
104 let mut passed_flags = Vec::with_capacity(n);
105 let mut values = Vec::with_capacity(n);
106 let mut assertions = Vec::with_capacity(n);
107 let mut operators = Vec::with_capacity(n);
108 let mut expected_jsons = Vec::with_capacity(n);
109 let mut actual_jsons = Vec::with_capacity(n);
110 let mut messages = Vec::with_capacity(n);
111 let mut condition = Vec::with_capacity(n);
112 let mut stage = Vec::with_capacity(n);
113
114 for r in records {
115 created_ats.push(r.created_at);
116 start_times.push(r.start_time);
117 end_times.push(r.end_time);
118 record_uids.push(&r.record_uid);
119 entity_ids.push(entity_id);
120 task_ids.push(&r.task_id);
121 task_types.push(r.task_type.as_str());
122 passed_flags.push(r.passed);
123 values.push(r.value);
124 assertions.push(Json(r.assertion()));
125 operators.push(r.operator.as_str());
126 expected_jsons.push(Json(&r.expected));
127 actual_jsons.push(Json(&r.actual));
128 messages.push(&r.message);
129 condition.push(r.condition);
130 stage.push(r.stage);
131 }
132
133 let query = Queries::InsertGenAITaskResultsBatch.get_query();
134
135 sqlx::query(query)
136 .bind(&created_ats)
137 .bind(&start_times)
138 .bind(&end_times)
139 .bind(&record_uids)
140 .bind(&entity_ids)
141 .bind(&task_ids)
142 .bind(&task_types)
143 .bind(&passed_flags)
144 .bind(&values)
145 .bind(&assertions)
146 .bind(&operators)
147 .bind(&expected_jsons)
148 .bind(&actual_jsons)
149 .bind(&messages)
150 .bind(&condition)
151 .bind(&stage)
152 .execute(pool)
153 .await
154 .map_err(SqlError::SqlxError)
155 }
156
157 async fn get_genai_eval_records(
158 pool: &Pool<Postgres>,
159 limit_datetime: Option<&DateTime<Utc>>,
160 status: Option<Status>,
161 entity_id: &i32,
162 ) -> Result<Vec<GenAIEvalRecord>, SqlError> {
163 let mut query_string = Queries::GetGenAIEvalRecords.get_query().to_string();
164 let mut bind_count = 1;
165
166 if limit_datetime.is_some() {
167 bind_count += 1;
168 query_string.push_str(&format!(" AND created_at > ${bind_count}"));
169 }
170
171 let status_value = status.as_ref().and_then(|s| s.as_str());
172 if status_value.is_some() {
173 bind_count += 1;
174 query_string.push_str(&format!(" AND status = ${bind_count}"));
175 }
176
177 let mut query = sqlx::query_as::<_, GenAIEvalRecord>(&query_string).bind(entity_id);
178
179 if let Some(datetime) = limit_datetime {
180 query = query.bind(datetime);
181 }
182 if let Some(status) = status_value {
184 query = query.bind(status);
185 }
186
187 let records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;
188
189 Ok(records
190 .into_iter()
191 .map(|mut r| {
192 r.mask_sensitive_data();
193 r
194 })
195 .collect())
196 }
197
    /// Fetch one page of GenAI evaluation records for an entity using keyset
    /// (cursor) pagination on `(created_at, id)`, with optional status and
    /// datetime-range filters. Sensitive data is masked before returning.
    ///
    /// NOTE(review): `has_more` assumes the SQL fetches `limit + 1` rows so
    /// the extra row acts as a "there is another page" sentinel — confirm
    /// against `GetPaginatedGenAIEvalRecords`.
    #[instrument(skip_all)]
    async fn get_paginated_genai_eval_records(
        pool: &Pool<Postgres>,
        params: &GenAIEvalRecordPaginationRequest,
        entity_id: &i32,
    ) -> Result<GenAIEvalRecordPaginationResponse, SqlError> {
        let query = Queries::GetPaginatedGenAIEvalRecords.get_query();
        // Defaults: page size 50, forward ("next") traversal.
        let limit = params.limit.unwrap_or(50);
        let direction = params.direction.as_deref().unwrap_or("next");

        let mut items: Vec<GenAIEvalRecord> = sqlx::query_as(query)
            .bind(entity_id)
            .bind(params.status.as_ref().and_then(|s| s.as_str()))
            .bind(params.cursor_created_at)
            .bind(direction)
            .bind(params.cursor_id)
            .bind(limit)
            .bind(params.start_datetime)
            .bind(params.end_datetime)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        // Extra row beyond `limit` signals another page; drop the sentinel
        // so it is not returned to the caller.
        let has_more = items.len() > limit as usize;

        if has_more {
            items.pop();
        }

        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
            "previous" => {
                // Backward pages arrive in reverse order; restore forward
                // order before deriving the page-edge cursors.
                items.reverse();

                let previous_cursor = if has_more {
                    items.first().map(|first| RecordCursor {
                        created_at: first.created_at,
                        id: first.id,
                    })
                } else {
                    None
                };

                let next_cursor = items.last().map(|last| RecordCursor {
                    created_at: last.created_at,
                    id: last.id,
                });

                // Arriving here via a cursor implies a later page exists.
                (
                    params.cursor_created_at.is_some(),
                    next_cursor,
                    has_more,
                    previous_cursor,
                )
            }
            _ => {
                // Forward ("next") traversal: sentinel row indicates more
                // pages ahead; having a cursor indicates pages behind.
                let next_cursor = if has_more {
                    items.last().map(|last| RecordCursor {
                        created_at: last.created_at,
                        id: last.id,
                    })
                } else {
                    None
                };

                let previous_cursor = items.first().map(|first| RecordCursor {
                    created_at: first.created_at,
                    id: first.id,
                });

                (
                    has_more,
                    next_cursor,
                    params.cursor_created_at.is_some(),
                    previous_cursor,
                )
            }
        };

        // Strip sensitive fields before handing records back to callers.
        let public_items = items
            .into_iter()
            .map(|mut r| {
                r.mask_sensitive_data();
                r
            })
            .collect();

        Ok(GenAIEvalRecordPaginationResponse {
            items: public_items,
            has_next,
            next_cursor,
            has_previous,
            previous_cursor,
        })
    }
302
303 async fn get_genai_eval_task(
310 pool: &Pool<Postgres>,
311 record_uid: &str,
312 ) -> Result<Vec<GenAIEvalTaskResult>, SqlError> {
313 let query = Queries::GetGenAIEvalTasks.get_query();
314 let tasks: Result<Vec<GenAIEvalTaskResult>, SqlError> = sqlx::query_as(query)
315 .bind(record_uid)
316 .fetch_all(pool)
317 .await
318 .map_err(SqlError::SqlxError);
319
320 tasks
321 }
322
    /// Fetch one page of GenAI evaluation workflow results for an entity
    /// using keyset (cursor) pagination on `(created_at, id)`, with optional
    /// datetime-range filters. Sensitive data is masked before returning.
    ///
    /// NOTE(review): `has_more` assumes the SQL fetches `limit + 1` rows so
    /// the extra row acts as a "there is another page" sentinel — confirm
    /// against `GetPaginatedGenAIEvalWorkflow`.
    #[instrument(skip_all)]
    async fn get_paginated_genai_eval_workflow_records(
        pool: &Pool<Postgres>,
        params: &GenAIEvalRecordPaginationRequest,
        entity_id: &i32,
    ) -> Result<GenAIEvalWorkflowPaginationResponse, SqlError> {
        let query = Queries::GetPaginatedGenAIEvalWorkflow.get_query();
        // Defaults: page size 50, forward ("next") traversal.
        let limit = params.limit.unwrap_or(50);
        let direction = params.direction.as_deref().unwrap_or("next");

        let mut items: Vec<GenAIEvalWorkflowResult> = sqlx::query_as(query)
            .bind(entity_id)
            .bind(params.cursor_created_at)
            .bind(direction)
            .bind(params.cursor_id)
            .bind(limit)
            .bind(params.start_datetime)
            .bind(params.end_datetime)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        // Extra row beyond `limit` signals another page; drop the sentinel
        // so it is not returned to the caller.
        let has_more = items.len() > limit as usize;

        if has_more {
            items.pop();
        }

        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
            "previous" => {
                // Backward pages arrive in reverse order; restore forward
                // order before deriving the page-edge cursors.
                items.reverse();

                let previous_cursor = if has_more {
                    items.first().map(|first| RecordCursor {
                        created_at: first.created_at,
                        id: first.id,
                    })
                } else {
                    None
                };

                let next_cursor = items.last().map(|last| RecordCursor {
                    created_at: last.created_at,
                    id: last.id,
                });

                // Arriving here via a cursor implies a later page exists.
                (
                    params.cursor_created_at.is_some(),
                    next_cursor,
                    has_more,
                    previous_cursor,
                )
            }
            _ => {
                // Forward ("next") traversal: sentinel row indicates more
                // pages ahead; having a cursor indicates pages behind.
                let next_cursor = if has_more {
                    items.last().map(|last| RecordCursor {
                        created_at: last.created_at,
                        id: last.id,
                    })
                } else {
                    None
                };

                let previous_cursor = items.first().map(|first| RecordCursor {
                    created_at: first.created_at,
                    id: first.id,
                });

                (
                    has_more,
                    next_cursor,
                    params.cursor_created_at.is_some(),
                    previous_cursor,
                )
            }
        };

        // Strip sensitive fields before handing records back to callers.
        let public_items = items
            .into_iter()
            .map(|mut r| {
                r.mask_sensitive_data();
                r
            })
            .collect();

        Ok(GenAIEvalWorkflowPaginationResponse {
            items: public_items,
            has_next,
            next_cursor,
            has_previous,
            previous_cursor,
        })
    }
426
427 #[instrument(skip_all)]
436 async fn get_genai_task_values(
437 pool: &Pool<Postgres>,
438 limit_datetime: &DateTime<Utc>,
439 metrics: &[String],
440 entity_id: &i32,
441 ) -> Result<HashMap<String, f64>, SqlError> {
442 let query = Queries::GetGenAITaskValues.get_query();
443
444 let records = sqlx::query(query)
445 .bind(limit_datetime)
446 .bind(entity_id)
447 .bind(metrics)
448 .fetch_all(pool)
449 .await
450 .map_err(SqlError::SqlxError)?;
451
452 let metric_map = records
453 .into_iter()
454 .map(|row| {
455 let metric = row.get("metric");
456 let value = row.get("value");
457 (metric, value)
458 })
459 .collect();
460
461 Ok(metric_map)
462 }
463
464 #[instrument(skip_all)]
472 async fn get_genai_workflow_value(
473 pool: &Pool<Postgres>,
474 limit_datetime: &DateTime<Utc>,
475 entity_id: &i32,
476 ) -> Result<Option<f64>, SqlError> {
477 let query = Queries::GetGenAIWorkflowValues.get_query();
478
479 let records = sqlx::query(query)
480 .bind(limit_datetime)
481 .bind(entity_id)
482 .fetch_optional(pool)
483 .await
484 .inspect_err(|e| {
485 error!("Error fetching GenAI workflow values: {:?}", e);
486 })?;
487
488 Ok(records.and_then(|r| r.try_get("value").ok()))
489 }
490
491 #[instrument(skip_all)]
502 async fn get_binned_workflow_records(
503 pool: &Pool<Postgres>,
504 params: &DriftRequest,
505 start_dt: DateTime<Utc>,
506 end_dt: DateTime<Utc>,
507 entity_id: &i32,
508 ) -> Result<BinnedMetrics, SqlError> {
509 let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
510 let bin = minutes / params.max_data_points as f64;
511
512 let query = Queries::GetGenAIWorkflowBinnedMetrics.get_query();
513
514 let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
515 .bind(bin)
516 .bind(start_dt)
517 .bind(end_dt)
518 .bind(entity_id)
519 .fetch_all(pool)
520 .await
521 .map_err(SqlError::SqlxError)?;
522
523 Ok(BinnedMetrics::from_vec(
524 records.into_iter().map(|wrapper| wrapper.0).collect(),
525 ))
526 }
527
528 #[instrument(skip_all)]
539 async fn get_binned_task_records(
540 pool: &Pool<Postgres>,
541 params: &DriftRequest,
542 start_dt: DateTime<Utc>,
543 end_dt: DateTime<Utc>,
544 entity_id: &i32,
545 ) -> Result<BinnedMetrics, SqlError> {
546 let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
547 let bin = minutes / params.max_data_points as f64;
548
549 let query = Queries::GetGenAITaskBinnedMetrics.get_query();
550
551 let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
552 .bind(bin)
553 .bind(start_dt)
554 .bind(end_dt)
555 .bind(entity_id)
556 .fetch_all(pool)
557 .await
558 .map_err(SqlError::SqlxError)?;
559
560 Ok(BinnedMetrics::from_vec(
561 records.into_iter().map(|wrapper| wrapper.0).collect(),
562 ))
563 }
564
565 fn merge_feature_results(
567 results: BinnedMetrics,
568 map: &mut BinnedMetrics,
569 ) -> Result<(), SqlError> {
570 for (name, metric) in results.metrics {
571 let metric_clone = metric.clone();
572 map.metrics
573 .entry(name)
574 .and_modify(|existing| {
575 existing.created_at.extend(metric_clone.created_at);
576 existing.stats.extend(metric_clone.stats);
577 })
578 .or_insert(metric);
579 }
580
581 Ok(())
582 }
583
584 #[instrument(skip_all)]
598 async fn get_archived_task_records(
599 params: &DriftRequest,
600 begin: DateTime<Utc>,
601 end: DateTime<Utc>,
602 minutes: i32,
603 storage_settings: &ObjectStorageSettings,
604 entity_id: &i32,
605 ) -> Result<BinnedMetrics, SqlError> {
606 debug!("Getting archived GenAI metrics for params: {:?}", params);
607 let path = format!("{}/{}", params.uid, RecordType::GenAITask);
608 let bin = minutes as f64 / params.max_data_points as f64;
609 let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAITask)?
610 .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
611 .await
612 .inspect_err(|e| {
613 error!("Failed to get archived GenAI metrics: {:?}", e);
614 })?;
615
616 Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
617 }
618
619 #[instrument(skip_all)]
633 async fn get_archived_workflow_records(
634 params: &DriftRequest,
635 begin: DateTime<Utc>,
636 end: DateTime<Utc>,
637 minutes: i32,
638 storage_settings: &ObjectStorageSettings,
639 entity_id: &i32,
640 ) -> Result<BinnedMetrics, SqlError> {
641 debug!("Getting archived GenAI metrics for params: {:?}", params);
642 let path = format!("{}/{}", params.uid, RecordType::GenAIWorkflow);
643 let bin = minutes as f64 / params.max_data_points as f64;
644 let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAIWorkflow)?
645 .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
646 .await
647 .inspect_err(|e| {
648 error!("Failed to get archived GenAI metrics: {:?}", e);
649 })?;
650
651 Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
652 }
653
654 #[instrument(skip_all)]
664 async fn get_binned_genai_task_values(
665 pool: &Pool<Postgres>,
666 params: &DriftRequest,
667 retention_period: &i32,
668 storage_settings: &ObjectStorageSettings,
669 entity_id: &i32,
670 ) -> Result<BinnedMetrics, SqlError> {
671 debug!("Getting binned task drift records for {:?}", params);
672
673 if !params.has_custom_interval() {
674 debug!("No custom interval provided, using default");
675 let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
676 return Self::get_binned_task_records(pool, params, start_dt, end_dt, entity_id).await;
677 }
678
679 debug!("Custom interval provided, using custom interval");
680 let interval = params.clone().to_custom_interval().unwrap();
681 let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
682 let mut custom_metric_map = BinnedMetrics::default();
683
684 if let Some((active_begin, active_end)) = timestamps.active_range {
686 let current_results =
687 Self::get_binned_task_records(pool, params, active_begin, active_end, entity_id)
688 .await?;
689 Self::merge_feature_results(current_results, &mut custom_metric_map)?;
690 }
691
692 if let Some((archive_begin, archive_end)) = timestamps.archived_range {
694 if let Some(archived_minutes) = timestamps.archived_minutes {
695 let archived_results = Self::get_archived_task_records(
696 params,
697 archive_begin,
698 archive_end,
699 archived_minutes,
700 storage_settings,
701 entity_id,
702 )
703 .await?;
704 Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
705 }
706 }
707
708 Ok(custom_metric_map)
709 }
710
711 #[instrument(skip_all)]
712 async fn get_binned_genai_workflow_values(
713 pool: &Pool<Postgres>,
714 params: &DriftRequest,
715 retention_period: &i32,
716 storage_settings: &ObjectStorageSettings,
717 entity_id: &i32,
718 ) -> Result<BinnedMetrics, SqlError> {
719 debug!("Getting binned workflow drift records for {:?}", params);
720
721 if !params.has_custom_interval() {
722 debug!("No custom interval provided, using default");
723 let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
724 return Self::get_binned_workflow_records(pool, params, start_dt, end_dt, entity_id)
725 .await;
726 }
727
728 debug!("Custom interval provided, using custom interval");
729 let interval = params.clone().to_custom_interval().unwrap();
730 let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
731 let mut custom_metric_map = BinnedMetrics::default();
732
733 if let Some((active_begin, active_end)) = timestamps.active_range {
735 let current_results = Self::get_binned_workflow_records(
736 pool,
737 params,
738 active_begin,
739 active_end,
740 entity_id,
741 )
742 .await?;
743 Self::merge_feature_results(current_results, &mut custom_metric_map)?;
744 }
745
746 if let Some((archive_begin, archive_end)) = timestamps.archived_range {
748 if let Some(archived_minutes) = timestamps.archived_minutes {
749 let archived_results = Self::get_archived_workflow_records(
750 params,
751 archive_begin,
752 archive_end,
753 archived_minutes,
754 storage_settings,
755 entity_id,
756 )
757 .await?;
758 Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
759 }
760 }
761
762 debug!(
763 "Custom metric map length: {:?}",
764 custom_metric_map.metrics.len()
765 );
766
767 Ok(custom_metric_map)
768 }
769
770 async fn get_pending_genai_eval_record(
772 pool: &Pool<Postgres>,
773 ) -> Result<Option<GenAIEvalRecord>, SqlError> {
774 let query = Queries::GetPendingGenAIEvalTask.get_query();
775 let result: Option<GenAIEvalRecord> = sqlx::query_as(query)
776 .fetch_optional(pool)
777 .await
778 .map_err(SqlError::SqlxError)?;
779
780 debug!("Fetched pending GenAI drift record: {:?}", result);
781
782 Ok(result)
783 }
784
785 #[instrument(skip_all)]
786 async fn update_genai_eval_record_status(
787 pool: &Pool<Postgres>,
788 record: &GenAIEvalRecord,
789 status: Status,
790 workflow_duration: &i64,
791 ) -> Result<(), SqlError> {
792 let query = Queries::UpdateGenAIEvalTask.get_query();
793 let _query_result = sqlx::query(query)
794 .bind(status.as_str())
795 .bind(workflow_duration)
796 .bind(&record.uid)
797 .execute(pool)
798 .await
799 .inspect_err(|e| {
800 error!("Failed to update GenAI drift record status: {:?}", e);
801 })?;
802
803 Ok(())
804 }
805
806 #[instrument(skip_all)]
807 async fn reschedule_genai_eval_record(
808 pool: &Pool<Postgres>,
809 uid: &str,
810 delay: Duration,
811 ) -> Result<(), SqlError> {
812 let scheduled_at = Utc::now() + delay;
813
814 let query = Queries::RescheduleGenAIEvalRecord.get_query();
815 sqlx::query(query)
816 .bind(scheduled_at)
817 .bind(uid)
818 .execute(pool)
819 .await
820 .inspect_err(|e| {
821 error!("Failed to reschedule GenAI eval record: {:?}", e);
822 })?;
823
824 Ok(())
825 }
826}