1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::schema::BinnedMetricWrapper;
4use crate::sql::utils::split_custom_interval;
5use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use scouter_dataframe::parquet::BinnedMetricsExtractor;
8use scouter_dataframe::parquet::ParquetDataFrame;
9use scouter_settings::ObjectStorageSettings;
10use scouter_types::contracts::DriftRequest;
11use scouter_types::BoxedGenAIEvalRecord;
12use scouter_types::GenAIEvalRecord;
13use scouter_types::GenAIEvalTaskResult;
14use scouter_types::GenAIEvalWorkflowPaginationResponse;
15use scouter_types::GenAIEvalWorkflowResult;
16use scouter_types::Status;
17use scouter_types::{
18 BinnedMetrics, GenAIEvalRecordPaginationRequest, GenAIEvalRecordPaginationResponse,
19 RecordCursor, RecordType,
20};
21use sqlx::types::Json;
22use sqlx::{postgres::PgQueryResult, Pool, Postgres, Row};
23use std::collections::HashMap;
24use tracing::error;
25use tracing::{debug, instrument};
26
27#[async_trait]
28pub trait GenAIDriftSqlLogic {
29 async fn insert_genai_eval_record(
36 pool: &Pool<Postgres>,
37 record: BoxedGenAIEvalRecord,
38 entity_id: &i32,
39 ) -> Result<PgQueryResult, SqlError> {
40 let query = Queries::InsertGenAIEvalRecord.get_query();
41
42 sqlx::query(query)
43 .bind(record.record.uid)
44 .bind(record.record.created_at)
45 .bind(entity_id)
46 .bind(Json(record.record.context))
47 .bind(&record.record.record_id)
48 .bind(&record.record.session_id)
49 .execute(pool)
50 .await
51 .map_err(SqlError::SqlxError)
52 }
53
54 async fn insert_genai_eval_workflow_record(
61 pool: &Pool<Postgres>,
62 record: &GenAIEvalWorkflowResult,
63 entity_id: &i32,
64 ) -> Result<PgQueryResult, SqlError> {
65 let query = Queries::InsertGenAIWorkflowResult.get_query();
66
67 sqlx::query(query)
68 .bind(record.created_at)
69 .bind(record.record_uid.as_str())
70 .bind(entity_id)
71 .bind(record.total_tasks)
72 .bind(record.passed_tasks)
73 .bind(record.failed_tasks)
74 .bind(record.pass_rate)
75 .bind(record.duration_ms)
76 .bind(Json(&record.execution_plan))
77 .execute(pool)
78 .await
79 .map_err(SqlError::SqlxError)
80 }
81
82 async fn insert_eval_task_results_batch(
85 pool: &Pool<Postgres>,
86 records: &[GenAIEvalTaskResult], entity_id: &i32,
88 ) -> Result<sqlx::postgres::PgQueryResult, SqlError> {
89 if records.is_empty() {
90 return Err(SqlError::EmptyBatchError);
91 }
92
93 let n = records.len();
94
95 let mut created_ats = Vec::with_capacity(n);
97 let mut start_times = Vec::with_capacity(n);
98 let mut end_times = Vec::with_capacity(n);
99 let mut record_uids = Vec::with_capacity(n);
100 let mut entity_ids = Vec::with_capacity(n);
101 let mut task_ids = Vec::with_capacity(n);
102 let mut task_types = Vec::with_capacity(n);
103 let mut passed_flags = Vec::with_capacity(n);
104 let mut values = Vec::with_capacity(n);
105 let mut field_paths = Vec::with_capacity(n);
106 let mut operators = Vec::with_capacity(n);
107 let mut expected_jsons = Vec::with_capacity(n);
108 let mut actual_jsons = Vec::with_capacity(n);
109 let mut messages = Vec::with_capacity(n);
110 let mut condition = Vec::with_capacity(n);
111 let mut stage = Vec::with_capacity(n);
112
113 for r in records {
114 created_ats.push(r.created_at);
115 start_times.push(r.start_time);
116 end_times.push(r.end_time);
117 record_uids.push(&r.record_uid);
118 entity_ids.push(entity_id);
119 task_ids.push(&r.task_id);
120 task_types.push(r.task_type.as_str());
121 passed_flags.push(r.passed);
122 values.push(r.value);
123 field_paths.push(r.field_path.as_deref());
124 operators.push(r.operator.as_str());
125 expected_jsons.push(Json(&r.expected));
126 actual_jsons.push(Json(&r.actual));
127 messages.push(&r.message);
128 condition.push(r.condition);
129 stage.push(r.stage);
130 }
131
132 let query = Queries::InsertGenAITaskResultsBatch.get_query();
133
134 sqlx::query(query)
135 .bind(&created_ats)
136 .bind(&start_times)
137 .bind(&end_times)
138 .bind(&record_uids)
139 .bind(&entity_ids)
140 .bind(&task_ids)
141 .bind(&task_types)
142 .bind(&passed_flags)
143 .bind(&values)
144 .bind(&field_paths)
145 .bind(&operators)
146 .bind(&expected_jsons)
147 .bind(&actual_jsons)
148 .bind(&messages)
149 .bind(&condition)
150 .bind(&stage)
151 .execute(pool)
152 .await
153 .map_err(SqlError::SqlxError)
154 }
155
156 async fn get_genai_eval_records(
157 pool: &Pool<Postgres>,
158 limit_datetime: Option<&DateTime<Utc>>,
159 status: Option<Status>,
160 entity_id: &i32,
161 ) -> Result<Vec<GenAIEvalRecord>, SqlError> {
162 let mut query_string = Queries::GetGenAIEvalRecords.get_query().to_string();
163 let mut bind_count = 1;
164
165 if limit_datetime.is_some() {
166 bind_count += 1;
167 query_string.push_str(&format!(" AND created_at > ${bind_count}"));
168 }
169
170 let status_value = status.as_ref().and_then(|s| s.as_str());
171 if status_value.is_some() {
172 bind_count += 1;
173 query_string.push_str(&format!(" AND status = ${bind_count}"));
174 }
175
176 let mut query = sqlx::query_as::<_, GenAIEvalRecord>(&query_string).bind(entity_id);
177
178 if let Some(datetime) = limit_datetime {
179 query = query.bind(datetime);
180 }
181 if let Some(status) = status_value {
183 query = query.bind(status);
184 }
185
186 let records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;
187
188 Ok(records
189 .into_iter()
190 .map(|mut r| {
191 r.mask_sensitive_data();
192 r
193 })
194 .collect())
195 }
196
197 #[instrument(skip_all)]
207 async fn get_paginated_genai_eval_records(
208 pool: &Pool<Postgres>,
209 params: &GenAIEvalRecordPaginationRequest,
210 entity_id: &i32,
211 ) -> Result<GenAIEvalRecordPaginationResponse, SqlError> {
212 let query = Queries::GetPaginatedGenAIEvalRecords.get_query();
213 let limit = params.limit.unwrap_or(50);
214 let direction = params.direction.as_deref().unwrap_or("next");
215
216 let mut items: Vec<GenAIEvalRecord> = sqlx::query_as(query)
217 .bind(entity_id)
218 .bind(params.status.as_ref().and_then(|s| s.as_str()))
219 .bind(params.cursor_created_at)
220 .bind(direction)
221 .bind(params.cursor_id)
222 .bind(limit)
223 .bind(params.start_datetime)
224 .bind(params.end_datetime)
225 .fetch_all(pool)
226 .await
227 .map_err(SqlError::SqlxError)?;
228
229 let has_more = items.len() > limit as usize;
230
231 if has_more {
232 items.pop();
233 }
234
235 let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
236 "previous" => {
237 items.reverse();
238
239 let previous_cursor = if has_more {
240 items.first().map(|first| RecordCursor {
241 created_at: first.created_at,
242 id: first.id,
243 })
244 } else {
245 None
246 };
247
248 let next_cursor = items.last().map(|last| RecordCursor {
249 created_at: last.created_at,
250 id: last.id,
251 });
252
253 (
254 params.cursor_created_at.is_some(),
255 next_cursor,
256 has_more,
257 previous_cursor,
258 )
259 }
260 _ => {
261 let next_cursor = if has_more {
263 items.last().map(|last| RecordCursor {
264 created_at: last.created_at,
265 id: last.id,
266 })
267 } else {
268 None
269 };
270
271 let previous_cursor = items.first().map(|first| RecordCursor {
272 created_at: first.created_at,
273 id: first.id,
274 });
275
276 (
277 has_more,
278 next_cursor,
279 params.cursor_created_at.is_some(),
280 previous_cursor,
281 )
282 }
283 };
284
285 let public_items = items
286 .into_iter()
287 .map(|mut r| {
288 r.mask_sensitive_data();
289 r
290 })
291 .collect();
292
293 Ok(GenAIEvalRecordPaginationResponse {
294 items: public_items,
295 has_next,
296 next_cursor,
297 has_previous,
298 previous_cursor,
299 })
300 }
301
302 async fn get_genai_eval_task(
309 pool: &Pool<Postgres>,
310 record_uid: &str,
311 ) -> Result<Vec<GenAIEvalTaskResult>, SqlError> {
312 let query = Queries::GetGenAIEvalTasks.get_query();
313 let tasks: Result<Vec<GenAIEvalTaskResult>, SqlError> = sqlx::query_as(query)
314 .bind(record_uid)
315 .fetch_all(pool)
316 .await
317 .map_err(SqlError::SqlxError);
318
319 tasks
320 }
321
322 #[instrument(skip_all)]
332 async fn get_paginated_genai_eval_workflow_records(
333 pool: &Pool<Postgres>,
334 params: &GenAIEvalRecordPaginationRequest,
335 entity_id: &i32,
336 ) -> Result<GenAIEvalWorkflowPaginationResponse, SqlError> {
337 let query = Queries::GetPaginatedGenAIEvalWorkflow.get_query();
338 let limit = params.limit.unwrap_or(50);
339 let direction = params.direction.as_deref().unwrap_or("next");
340
341 let mut items: Vec<GenAIEvalWorkflowResult> = sqlx::query_as(query)
342 .bind(entity_id)
343 .bind(params.cursor_created_at)
344 .bind(direction)
345 .bind(params.cursor_id)
346 .bind(limit)
347 .bind(params.start_datetime)
348 .bind(params.end_datetime)
349 .fetch_all(pool)
350 .await
351 .map_err(SqlError::SqlxError)?;
352
353 let has_more = items.len() > limit as usize;
354
355 if has_more {
356 items.pop();
357 }
358
359 let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
360 "previous" => {
361 items.reverse();
362
363 let previous_cursor = if has_more {
364 items.first().map(|first| RecordCursor {
365 created_at: first.created_at,
366 id: first.id,
367 })
368 } else {
369 None
370 };
371
372 let next_cursor = items.last().map(|last| RecordCursor {
373 created_at: last.created_at,
374 id: last.id,
375 });
376
377 (
378 params.cursor_created_at.is_some(),
379 next_cursor,
380 has_more,
381 previous_cursor,
382 )
383 }
384 _ => {
385 let next_cursor = if has_more {
387 items.last().map(|last| RecordCursor {
388 created_at: last.created_at,
389 id: last.id,
390 })
391 } else {
392 None
393 };
394
395 let previous_cursor = items.first().map(|first| RecordCursor {
396 created_at: first.created_at,
397 id: first.id,
398 });
399
400 (
401 has_more,
402 next_cursor,
403 params.cursor_created_at.is_some(),
404 previous_cursor,
405 )
406 }
407 };
408
409 let public_items = items
410 .into_iter()
411 .map(|mut r| {
412 r.mask_sensitive_data();
413 r
414 })
415 .collect();
416
417 Ok(GenAIEvalWorkflowPaginationResponse {
418 items: public_items,
419 has_next,
420 next_cursor,
421 has_previous,
422 previous_cursor,
423 })
424 }
425
426 #[instrument(skip_all)]
435 async fn get_genai_task_values(
436 pool: &Pool<Postgres>,
437 limit_datetime: &DateTime<Utc>,
438 metrics: &[String],
439 entity_id: &i32,
440 ) -> Result<HashMap<String, f64>, SqlError> {
441 let query = Queries::GetGenAITaskValues.get_query();
442
443 let records = sqlx::query(query)
444 .bind(limit_datetime)
445 .bind(entity_id)
446 .bind(metrics)
447 .fetch_all(pool)
448 .await
449 .map_err(SqlError::SqlxError)?;
450
451 let metric_map = records
452 .into_iter()
453 .map(|row| {
454 let metric = row.get("metric");
455 let value = row.get("value");
456 (metric, value)
457 })
458 .collect();
459
460 Ok(metric_map)
461 }
462
463 #[instrument(skip_all)]
471 async fn get_genai_workflow_value(
472 pool: &Pool<Postgres>,
473 limit_datetime: &DateTime<Utc>,
474 entity_id: &i32,
475 ) -> Result<Option<f64>, SqlError> {
476 let query = Queries::GetGenAIWorkflowValues.get_query();
477
478 let records = sqlx::query(query)
479 .bind(limit_datetime)
480 .bind(entity_id)
481 .fetch_optional(pool)
482 .await
483 .inspect_err(|e| {
484 error!("Error fetching GenAI workflow values: {:?}", e);
485 })?;
486
487 Ok(records.and_then(|r| r.try_get("value").ok()))
488 }
489
490 #[instrument(skip_all)]
501 async fn get_binned_workflow_records(
502 pool: &Pool<Postgres>,
503 params: &DriftRequest,
504 start_dt: DateTime<Utc>,
505 end_dt: DateTime<Utc>,
506 entity_id: &i32,
507 ) -> Result<BinnedMetrics, SqlError> {
508 let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
509 let bin = minutes / params.max_data_points as f64;
510
511 let query = Queries::GetGenAIWorkflowBinnedMetrics.get_query();
512
513 let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
514 .bind(bin)
515 .bind(start_dt)
516 .bind(end_dt)
517 .bind(entity_id)
518 .fetch_all(pool)
519 .await
520 .map_err(SqlError::SqlxError)?;
521
522 Ok(BinnedMetrics::from_vec(
523 records.into_iter().map(|wrapper| wrapper.0).collect(),
524 ))
525 }
526
527 #[instrument(skip_all)]
538 async fn get_binned_task_records(
539 pool: &Pool<Postgres>,
540 params: &DriftRequest,
541 start_dt: DateTime<Utc>,
542 end_dt: DateTime<Utc>,
543 entity_id: &i32,
544 ) -> Result<BinnedMetrics, SqlError> {
545 let minutes = end_dt.signed_duration_since(start_dt).num_minutes() as f64;
546 let bin = minutes / params.max_data_points as f64;
547
548 let query = Queries::GetGenAITaskBinnedMetrics.get_query();
549
550 let records: Vec<BinnedMetricWrapper> = sqlx::query_as(query)
551 .bind(bin)
552 .bind(start_dt)
553 .bind(end_dt)
554 .bind(entity_id)
555 .fetch_all(pool)
556 .await
557 .map_err(SqlError::SqlxError)?;
558
559 Ok(BinnedMetrics::from_vec(
560 records.into_iter().map(|wrapper| wrapper.0).collect(),
561 ))
562 }
563
564 fn merge_feature_results(
566 results: BinnedMetrics,
567 map: &mut BinnedMetrics,
568 ) -> Result<(), SqlError> {
569 for (name, metric) in results.metrics {
570 let metric_clone = metric.clone();
571 map.metrics
572 .entry(name)
573 .and_modify(|existing| {
574 existing.created_at.extend(metric_clone.created_at);
575 existing.stats.extend(metric_clone.stats);
576 })
577 .or_insert(metric);
578 }
579
580 Ok(())
581 }
582
583 #[instrument(skip_all)]
597 async fn get_archived_task_records(
598 params: &DriftRequest,
599 begin: DateTime<Utc>,
600 end: DateTime<Utc>,
601 minutes: i32,
602 storage_settings: &ObjectStorageSettings,
603 entity_id: &i32,
604 ) -> Result<BinnedMetrics, SqlError> {
605 debug!("Getting archived GenAI metrics for params: {:?}", params);
606 let path = format!("{}/{}", params.uid, RecordType::GenAITask);
607 let bin = minutes as f64 / params.max_data_points as f64;
608 let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAITask)?
609 .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
610 .await
611 .inspect_err(|e| {
612 error!("Failed to get archived GenAI metrics: {:?}", e);
613 })?;
614
615 Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
616 }
617
618 #[instrument(skip_all)]
632 async fn get_archived_workflow_records(
633 params: &DriftRequest,
634 begin: DateTime<Utc>,
635 end: DateTime<Utc>,
636 minutes: i32,
637 storage_settings: &ObjectStorageSettings,
638 entity_id: &i32,
639 ) -> Result<BinnedMetrics, SqlError> {
640 debug!("Getting archived GenAI metrics for params: {:?}", params);
641 let path = format!("{}/{}", params.uid, RecordType::GenAIWorkflow);
642 let bin = minutes as f64 / params.max_data_points as f64;
643 let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::GenAIWorkflow)?
644 .get_binned_metrics(&path, &bin, &begin, &end, entity_id)
645 .await
646 .inspect_err(|e| {
647 error!("Failed to get archived GenAI metrics: {:?}", e);
648 })?;
649
650 Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
651 }
652
653 #[instrument(skip_all)]
663 async fn get_binned_genai_task_values(
664 pool: &Pool<Postgres>,
665 params: &DriftRequest,
666 retention_period: &i32,
667 storage_settings: &ObjectStorageSettings,
668 entity_id: &i32,
669 ) -> Result<BinnedMetrics, SqlError> {
670 debug!("Getting binned task drift records for {:?}", params);
671
672 if !params.has_custom_interval() {
673 debug!("No custom interval provided, using default");
674 let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
675 return Self::get_binned_task_records(pool, params, start_dt, end_dt, entity_id).await;
676 }
677
678 debug!("Custom interval provided, using custom interval");
679 let interval = params.clone().to_custom_interval().unwrap();
680 let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
681 let mut custom_metric_map = BinnedMetrics::default();
682
683 if let Some((active_begin, active_end)) = timestamps.active_range {
685 let current_results =
686 Self::get_binned_task_records(pool, params, active_begin, active_end, entity_id)
687 .await?;
688 Self::merge_feature_results(current_results, &mut custom_metric_map)?;
689 }
690
691 if let Some((archive_begin, archive_end)) = timestamps.archived_range {
693 if let Some(archived_minutes) = timestamps.archived_minutes {
694 let archived_results = Self::get_archived_task_records(
695 params,
696 archive_begin,
697 archive_end,
698 archived_minutes,
699 storage_settings,
700 entity_id,
701 )
702 .await?;
703 Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
704 }
705 }
706
707 Ok(custom_metric_map)
708 }
709
710 #[instrument(skip_all)]
711 async fn get_binned_genai_workflow_values(
712 pool: &Pool<Postgres>,
713 params: &DriftRequest,
714 retention_period: &i32,
715 storage_settings: &ObjectStorageSettings,
716 entity_id: &i32,
717 ) -> Result<BinnedMetrics, SqlError> {
718 debug!("Getting binned workflow drift records for {:?}", params);
719
720 if !params.has_custom_interval() {
721 debug!("No custom interval provided, using default");
722 let (start_dt, end_dt) = params.time_interval.to_begin_end_times()?;
723 return Self::get_binned_workflow_records(pool, params, start_dt, end_dt, entity_id)
724 .await;
725 }
726
727 debug!("Custom interval provided, using custom interval");
728 let interval = params.clone().to_custom_interval().unwrap();
729 let timestamps = split_custom_interval(interval.begin, interval.end, retention_period)?;
730 let mut custom_metric_map = BinnedMetrics::default();
731
732 if let Some((active_begin, active_end)) = timestamps.active_range {
734 let current_results = Self::get_binned_workflow_records(
735 pool,
736 params,
737 active_begin,
738 active_end,
739 entity_id,
740 )
741 .await?;
742 Self::merge_feature_results(current_results, &mut custom_metric_map)?;
743 }
744
745 if let Some((archive_begin, archive_end)) = timestamps.archived_range {
747 if let Some(archived_minutes) = timestamps.archived_minutes {
748 let archived_results = Self::get_archived_workflow_records(
749 params,
750 archive_begin,
751 archive_end,
752 archived_minutes,
753 storage_settings,
754 entity_id,
755 )
756 .await?;
757 Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
758 }
759 }
760
761 debug!(
762 "Custom metric map length: {:?}",
763 custom_metric_map.metrics.len()
764 );
765
766 Ok(custom_metric_map)
767 }
768
769 async fn get_pending_genai_eval_record(
771 pool: &Pool<Postgres>,
772 ) -> Result<Option<GenAIEvalRecord>, SqlError> {
773 let query = Queries::GetPendingGenAIEvalTask.get_query();
774 let result: Option<GenAIEvalRecord> = sqlx::query_as(query)
775 .fetch_optional(pool)
776 .await
777 .map_err(SqlError::SqlxError)?;
778
779 debug!("Fetched pending GenAI drift record: {:?}", result);
780
781 Ok(result)
782 }
783
784 #[instrument(skip_all)]
785 async fn update_genai_eval_record_status(
786 pool: &Pool<Postgres>,
787 record: &GenAIEvalRecord,
788 status: Status,
789 workflow_duration: &i64,
790 ) -> Result<(), SqlError> {
791 let query = Queries::UpdateGenAIEvalTask.get_query();
792 let _query_result = sqlx::query(query)
793 .bind(status.as_str())
794 .bind(workflow_duration)
795 .bind(&record.uid)
796 .execute(pool)
797 .await
798 .inspect_err(|e| {
799 error!("Failed to update GenAI drift record status: {:?}", e);
800 })?;
801
802 Ok(())
803 }
804}