1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::schema::LLMRecordWrapper;
4use crate::sql::schema::{BinnedMetricWrapper, LLMDriftServerSQLRecord};
5use crate::sql::utils::split_custom_interval;
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use itertools::multiunzip;
9use scouter_dataframe::parquet::BinnedMetricsExtractor;
10use scouter_dataframe::parquet::ParquetDataFrame;
11use scouter_settings::ObjectStorageSettings;
12use scouter_types::contracts::{DriftRequest, ServiceInfo};
13use scouter_types::LLMRecord;
14use scouter_types::{
15 llm::{PaginationCursor, PaginationRequest, PaginationResponse},
16 BinnedMetrics, LLMDriftServerRecord, RecordType,
17};
18use scouter_types::{LLMMetricRecord, Status};
19use sqlx::types::Json;
20use sqlx::{postgres::PgQueryResult, Pool, Postgres, Row};
21use std::collections::HashMap;
22use tracing::error;
23use tracing::{debug, instrument};
24
25#[async_trait]
26pub trait LLMDriftSqlLogic {
27 async fn insert_llm_drift_record(
34 pool: &Pool<Postgres>,
35 record: &LLMDriftServerRecord,
36 ) -> Result<PgQueryResult, SqlError> {
37 let query = Queries::InsertLLMDriftRecord.get_query();
38
39 sqlx::query(&query.sql)
40 .bind(record.created_at)
41 .bind(&record.space)
42 .bind(&record.name)
43 .bind(&record.version)
44 .bind(&record.context)
45 .bind(Json(&record.prompt))
46 .execute(pool)
47 .await
48 .map_err(SqlError::SqlxError)
49 }
50
51 async fn insert_llm_metric_values_batch(
54 pool: &Pool<Postgres>,
55 records: &[LLMMetricRecord],
56 ) -> Result<PgQueryResult, SqlError> {
57 if records.is_empty() {
58 return Err(SqlError::EmptyBatchError);
59 }
60
61 let query = Queries::InsertLLMMetricValuesBatch.get_query();
62
63 let (created_ats, record_uids, names, spaces, versions, metrics, values): (
64 Vec<DateTime<Utc>>,
65 Vec<&str>,
66 Vec<&str>,
67 Vec<&str>,
68 Vec<&str>,
69 Vec<&str>,
70 Vec<f64>,
71 ) = multiunzip(records.iter().map(|r| {
72 (
73 r.created_at,
74 r.record_uid.as_str(),
75 r.name.as_str(),
76 r.space.as_str(),
77 r.version.as_str(),
78 r.metric.as_str(),
79 r.value,
80 )
81 }));
82
83 sqlx::query(&query.sql)
84 .bind(created_ats)
85 .bind(record_uids)
86 .bind(spaces)
87 .bind(names)
88 .bind(versions)
89 .bind(metrics)
90 .bind(values)
91 .execute(pool)
92 .await
93 .map_err(SqlError::SqlxError)
94 }
95
96 async fn get_llm_drift_records(
97 pool: &Pool<Postgres>,
98 service_info: &ServiceInfo,
99 limit_datetime: Option<&DateTime<Utc>>,
100 status: Option<Status>,
101 ) -> Result<Vec<LLMDriftServerRecord>, SqlError> {
102 let mut query_string = Queries::GetLLMDriftRecords.get_query().sql;
103
104 let mut bind_count = 3;
105
106 if limit_datetime.is_some() {
107 bind_count += 1;
108 query_string.push_str(&format!(" AND created_at > ${bind_count}"));
109 }
110
111 let status_value = status.as_ref().and_then(|s| s.as_str());
112 if status_value.is_some() {
113 bind_count += 1;
114 query_string.push_str(&format!(" AND status = ${bind_count}"));
115 }
116
117 let mut query = sqlx::query_as::<_, LLMDriftServerSQLRecord>(&query_string)
118 .bind(&service_info.space)
119 .bind(&service_info.name)
120 .bind(&service_info.version);
121
122 if let Some(datetime) = limit_datetime {
123 query = query.bind(datetime);
124 }
125 if let Some(status) = status_value {
127 query = query.bind(status);
128 }
129
130 let records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;
131
132 Ok(records
133 .into_iter()
134 .map(LLMDriftServerRecord::from)
135 .collect())
136 }
137
138 #[instrument(skip_all)]
148 async fn get_llm_drift_records_pagination(
149 pool: &Pool<Postgres>,
150 service_info: &ServiceInfo,
151 status: Option<Status>,
152 pagination: PaginationRequest,
153 ) -> Result<PaginationResponse<LLMDriftServerRecord>, SqlError> {
154 let limit = pagination.limit.clamp(1, 100); let query_limit = limit + 1;
156
157 let mut sql = Queries::GetLLMDriftRecords.get_query().sql;
159 let mut bind_count = 3;
160
161 if pagination.cursor.is_some() {
164 bind_count += 1;
165 sql.push_str(&format!(" AND id < ${bind_count}"));
166 }
167
168 let status_value = status.as_ref().and_then(|s| s.as_str());
170 if status_value.is_some() {
171 bind_count += 1;
172 sql.push_str(&format!(" AND status = ${bind_count}"));
173 }
174
175 sql.push_str(&format!(" ORDER BY id DESC LIMIT ${}", bind_count + 1));
176
177 let mut query = sqlx::query_as::<_, LLMDriftServerSQLRecord>(&sql)
178 .bind(&service_info.space)
179 .bind(&service_info.name)
180 .bind(&service_info.version);
181
182 if let Some(cursor) = &pagination.cursor {
184 query = query.bind(cursor.id);
185 }
186
187 if let Some(status) = status_value {
189 query = query.bind(status);
190 }
191
192 query = query.bind(query_limit);
194
195 let mut records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;
196
197 let has_more = records.len() > limit as usize;
199 if has_more {
200 records.pop(); }
202
203 let next_cursor = if has_more && !records.is_empty() {
204 let last_record = records.last().unwrap();
205 Some(PaginationCursor { id: last_record.id })
206 } else {
207 None
208 };
209
210 let items = records
211 .into_iter()
212 .map(LLMDriftServerRecord::from)
213 .collect();
214
215 Ok(PaginationResponse {
216 items,
217 next_cursor,
218 has_more,
219 })
220 }
221
222 async fn get_llm_metric_values(
223 pool: &Pool<Postgres>,
224 service_info: &ServiceInfo,
225 limit_datetime: &DateTime<Utc>,
226 metrics: &[String],
227 ) -> Result<HashMap<String, f64>, SqlError> {
228 let query = Queries::GetLLMMetricValues.get_query();
229
230 let records = sqlx::query(&query.sql)
231 .bind(limit_datetime)
232 .bind(&service_info.space)
233 .bind(&service_info.name)
234 .bind(&service_info.version)
235 .bind(metrics)
236 .fetch_all(pool)
237 .await
238 .map_err(SqlError::SqlxError)?;
239
240 let metric_map = records
241 .into_iter()
242 .map(|row| {
243 let metric = row.get("metric");
244 let value = row.get("value");
245 (metric, value)
246 })
247 .collect();
248
249 Ok(metric_map)
250 }
251
252 #[instrument(skip_all)]
262 async fn get_records(
263 pool: &Pool<Postgres>,
264 params: &DriftRequest,
265 minutes: i32,
266 ) -> Result<BinnedMetrics, SqlError> {
267 let bin = params.time_interval.to_minutes() as f64 / params.max_data_points as f64;
268
269 let query = Queries::GetBinnedMetrics.get_query();
270
271 let records: Vec<BinnedMetricWrapper> = sqlx::query_as(&query.sql)
272 .bind(bin)
273 .bind(minutes)
274 .bind(¶ms.space)
275 .bind(¶ms.name)
276 .bind(¶ms.version)
277 .fetch_all(pool)
278 .await
279 .map_err(SqlError::SqlxError)?;
280
281 Ok(BinnedMetrics::from_vec(
282 records.into_iter().map(|wrapper| wrapper.0).collect(),
283 ))
284 }
285
286 fn merge_feature_results(
288 results: BinnedMetrics,
289 map: &mut BinnedMetrics,
290 ) -> Result<(), SqlError> {
291 for (name, metric) in results.metrics {
292 let metric_clone = metric.clone();
293 map.metrics
294 .entry(name)
295 .and_modify(|existing| {
296 existing.created_at.extend(metric_clone.created_at);
297 existing.stats.extend(metric_clone.stats);
298 })
299 .or_insert(metric);
300 }
301
302 Ok(())
303 }
304
305 #[instrument(skip_all)]
317 async fn get_archived_metric_records(
318 params: &DriftRequest,
319 begin: DateTime<Utc>,
320 end: DateTime<Utc>,
321 minutes: i32,
322 storage_settings: &ObjectStorageSettings,
323 ) -> Result<BinnedMetrics, SqlError> {
324 debug!("Getting archived LLM metrics for params: {:?}", params);
325 let path = format!(
326 "{}/{}/{}/llm_metric",
327 params.space, params.name, params.version
328 );
329 let bin = minutes as f64 / params.max_data_points as f64;
330 let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::LLMMetric)?
331 .get_binned_metrics(
332 &path,
333 &bin,
334 &begin,
335 &end,
336 ¶ms.space,
337 ¶ms.name,
338 ¶ms.version,
339 )
340 .await
341 .inspect_err(|e| {
342 error!("Failed to get archived LLM metrics: {:?}", e);
343 })?;
344
345 Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
346 }
347
348 #[instrument(skip_all)]
358 async fn get_binned_llm_metric_values(
359 pool: &Pool<Postgres>,
360 params: &DriftRequest,
361 retention_period: &i32,
362 storage_settings: &ObjectStorageSettings,
363 ) -> Result<BinnedMetrics, SqlError> {
364 debug!("Getting binned Custom drift records for {:?}", params);
365
366 if !params.has_custom_interval() {
367 debug!("No custom interval provided, using default");
368 let minutes = params.time_interval.to_minutes();
369 return Self::get_records(pool, params, minutes).await;
370 }
371
372 debug!("Custom interval provided, using custom interval");
373 let interval = params.clone().to_custom_interval().unwrap();
374 let timestamps = split_custom_interval(interval.start, interval.end, retention_period)?;
375 let mut custom_metric_map = BinnedMetrics::default();
376
377 if let Some(minutes) = timestamps.current_minutes {
379 let current_results = Self::get_records(pool, params, minutes).await?;
380 Self::merge_feature_results(current_results, &mut custom_metric_map)?;
381 }
382
383 if let Some((archive_begin, archive_end)) = timestamps.archived_range {
385 if let Some(archived_minutes) = timestamps.archived_minutes {
386 let archived_results = Self::get_archived_metric_records(
387 params,
388 archive_begin,
389 archive_end,
390 archived_minutes,
391 storage_settings,
392 )
393 .await?;
394 Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
395 }
396 }
397
398 Ok(custom_metric_map)
399 }
400
401 async fn get_pending_llm_drift_record(
403 pool: &Pool<Postgres>,
404 ) -> Result<Option<LLMRecord>, SqlError> {
405 let query = Queries::GetPendingLLMDriftTask.get_query();
406 let result: Option<LLMRecordWrapper> = sqlx::query_as(&query.sql)
407 .fetch_optional(pool)
408 .await
409 .map_err(SqlError::SqlxError)?;
410
411 Ok(result.map(|wrapper| wrapper.0))
412 }
413
414 #[instrument(skip_all)]
415 async fn update_llm_drift_record_status(
416 pool: &Pool<Postgres>,
417 record: &LLMRecord,
418 status: Status,
419 workflow_duration: Option<i32>, ) -> Result<(), SqlError> {
421 let query = Queries::UpdateLLMDriftTask.get_query();
422
423 let _query_result = sqlx::query(&query.sql)
424 .bind(status.as_str())
425 .bind(record.score.clone())
426 .bind(workflow_duration)
427 .bind(&record.uid)
428 .execute(pool)
429 .await
430 .inspect_err(|e| {
431 error!("Failed to update LLM drift record status: {:?}", e);
432 })?;
433
434 Ok(())
435 }
436}