scouter_sql/sql/traits/llm.rs

use crate::sql::error::SqlError;
use crate::sql::query::Queries;
use crate::sql::schema::{BinnedMetricWrapper, LLMDriftServerSQLRecord, LLMRecordWrapper};
use crate::sql::utils::split_custom_interval;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use itertools::multiunzip;
use scouter_dataframe::parquet::{BinnedMetricsExtractor, ParquetDataFrame};
use scouter_settings::ObjectStorageSettings;
use scouter_types::contracts::{DriftRequest, ServiceInfo};
use scouter_types::{
    llm::{PaginationCursor, PaginationRequest, PaginationResponse},
    BinnedMetrics, LLMDriftServerRecord, LLMMetricRecord, LLMRecord, RecordType, Status,
};
use sqlx::types::Json;
use sqlx::{postgres::PgQueryResult, Pool, Postgres, Row};
use std::collections::HashMap;
use tracing::{debug, error, instrument};

#[async_trait]
pub trait LLMDriftSqlLogic {
    /// Inserts an LLM drift record into the database.
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `record` - The LLM drift record to insert
    /// # Returns
    /// * A result containing the query result or an error
    async fn insert_llm_drift_record(
        pool: &Pool<Postgres>,
        record: &LLMDriftServerRecord,
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertLLMDriftRecord.get_query();

        sqlx::query(&query.sql)
            .bind(record.created_at)
            .bind(&record.space)
            .bind(&record.name)
            .bind(&record.version)
            .bind(&record.context)
            .bind(Json(&record.prompt))
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }
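
    // Usage sketch (not from this crate): the methods here take the pool as an
    // argument, so a unit struct is enough to implement the trait. `PgClient`
    // and `store_record` are hypothetical names for illustration.
    //
    //     struct PgClient;
    //     impl LLMDriftSqlLogic for PgClient {}
    //
    //     async fn store_record(
    //         pool: &Pool<Postgres>,
    //         record: &LLMDriftServerRecord,
    //     ) -> Result<(), SqlError> {
    //         PgClient::insert_llm_drift_record(pool, record).await?;
    //         Ok(())
    //     }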

    /// Inserts a batch of LLM metric values into the database.
    /// This is the output from processing/evaluating the LLM drift records.
    async fn insert_llm_metric_values_batch(
        pool: &Pool<Postgres>,
        records: &[LLMMetricRecord],
    ) -> Result<PgQueryResult, SqlError> {
        if records.is_empty() {
            return Err(SqlError::EmptyBatchError);
        }

        let query = Queries::InsertLLMMetricValuesBatch.get_query();

        let (created_ats, record_uids, names, spaces, versions, metrics, values): (
            Vec<DateTime<Utc>>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<&str>,
            Vec<f64>,
        ) = multiunzip(records.iter().map(|r| {
            (
                r.created_at,
                r.record_uid.as_str(),
                r.name.as_str(),
                r.space.as_str(),
                r.version.as_str(),
                r.metric.as_str(),
                r.value,
            )
        }));

        sqlx::query(&query.sql)
            .bind(created_ats)
            .bind(record_uids)
            .bind(spaces)
            .bind(names)
            .bind(versions)
            .bind(metrics)
            .bind(values)
            .execute(pool)
            .await
            .map_err(SqlError::SqlxError)
    }
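
    // Why the transposition above: sqlx binds one value per `$n` placeholder,
    // so the batch is flipped from rows into parallel column Vecs that the
    // insert statement (defined in `Queries::InsertLLMMetricValuesBatch`) can
    // consume in a single round trip, typically via Postgres `UNNEST`. A
    // minimal, self-contained sketch of the row-to-column flip:
    //
    //     use itertools::multiunzip;
    //
    //     let rows = vec![("coherence", 0.9), ("relevance", 0.7)];
    //     let (metrics, values): (Vec<&str>, Vec<f64>) = multiunzip(rows);
    //     assert_eq!(metrics, vec!["coherence", "relevance"]);
    //     assert_eq!(values, vec![0.9, 0.7]);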

    /// Retrieves LLM drift records for a service, optionally filtered by a
    /// creation-time lower bound and a status.
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `service_info` - The service information to filter records by
    /// * `limit_datetime` - Optional lower bound; only records created after it are returned
    /// * `status` - Optional status filter for the records
    /// # Returns
    /// * A result containing the matching LLM drift records or an error
    async fn get_llm_drift_records(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        limit_datetime: Option<&DateTime<Utc>>,
        status: Option<Status>,
    ) -> Result<Vec<LLMDriftServerRecord>, SqlError> {
        let mut query_string = Queries::GetLLMDriftRecords.get_query().sql;

        // The base query already binds space, name, and version ($1-$3);
        // optional filters are appended with the next placeholder numbers.
        let mut bind_count = 3;

        if limit_datetime.is_some() {
            bind_count += 1;
            query_string.push_str(&format!(" AND created_at > ${bind_count}"));
        }

        let status_value = status.as_ref().and_then(|s| s.as_str());
        if status_value.is_some() {
            bind_count += 1;
            query_string.push_str(&format!(" AND status = ${bind_count}"));
        }

        let mut query = sqlx::query_as::<_, LLMDriftServerSQLRecord>(&query_string)
            .bind(&service_info.space)
            .bind(&service_info.name)
            .bind(&service_info.version);

        if let Some(datetime) = limit_datetime {
            query = query.bind(datetime);
        }
        // Bind status if provided
        if let Some(status) = status_value {
            query = query.bind(status);
        }

        let records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;

        Ok(records
            .into_iter()
            .map(LLMDriftServerRecord::from)
            .collect())
    }

    /// Retrieves a paginated list of LLM drift records from the database
    /// for a given service.
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `service_info` - The service information to filter records by
    /// * `status` - Optional status filter for the records
    /// * `pagination` - The pagination request containing limit and cursor
    /// # Returns
    /// * A result containing a pagination response with LLM drift records or an error
    #[instrument(skip_all)]
    async fn get_llm_drift_records_pagination(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        status: Option<Status>,
        pagination: PaginationRequest,
    ) -> Result<PaginationResponse<LLMDriftServerRecord>, SqlError> {
        let limit = pagination.limit.clamp(1, 100); // Cap at 100, min 1
        let query_limit = limit + 1; // Fetch one extra row to detect whether more pages exist

        // Get initial SQL query
        let mut sql = Queries::GetLLMDriftRecords.get_query().sql;
        let mut bind_count = 3;

        // Results are ordered by id DESC (most recent first), so for any page
        // after the first we filter for ids below the cursor.
        if pagination.cursor.is_some() {
            bind_count += 1;
            sql.push_str(&format!(" AND id < ${bind_count}"));
        }

        // Optional status filter
        let status_value = status.as_ref().and_then(|s| s.as_str());
        if status_value.is_some() {
            bind_count += 1;
            sql.push_str(&format!(" AND status = ${bind_count}"));
        }

        sql.push_str(&format!(" ORDER BY id DESC LIMIT ${}", bind_count + 1));

        let mut query = sqlx::query_as::<_, LLMDriftServerSQLRecord>(&sql)
            .bind(&service_info.space)
            .bind(&service_info.name)
            .bind(&service_info.version);

        // Bind cursor parameter
        if let Some(cursor) = &pagination.cursor {
            query = query.bind(cursor.id);
        }

        // Bind status if provided
        if let Some(status) = status_value {
            query = query.bind(status);
        }

        // Bind limit
        query = query.bind(query_limit);

        let mut records = query.fetch_all(pool).await.map_err(SqlError::SqlxError)?;

        // Check if there are more records
        let has_more = records.len() > limit as usize;
        if has_more {
            records.pop(); // Remove the extra record
        }

        let next_cursor = if has_more && !records.is_empty() {
            let last_record = records.last().unwrap();
            Some(PaginationCursor { id: last_record.id })
        } else {
            None
        };

        let items = records
            .into_iter()
            .map(LLMDriftServerRecord::from)
            .collect();

        Ok(PaginationResponse {
            items,
            next_cursor,
            has_more,
        })
    }
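
    // Caller-side sketch of the cursor contract: feed `next_cursor` back into
    // the next request until `has_more` is false. `PgClient` is the same
    // hypothetical implementor as above, and the sketch assumes
    // `PaginationRequest` is constructible from its `limit` and `cursor` fields.
    //
    //     async fn fetch_all_pages(
    //         pool: &Pool<Postgres>,
    //         info: &ServiceInfo,
    //     ) -> Result<Vec<LLMDriftServerRecord>, SqlError> {
    //         let mut items = Vec::new();
    //         let mut cursor = None;
    //         loop {
    //             let page = PgClient::get_llm_drift_records_pagination(
    //                 pool,
    //                 info,
    //                 None,
    //                 PaginationRequest { limit: 100, cursor },
    //             )
    //             .await?;
    //             items.extend(page.items);
    //             if !page.has_more {
    //                 return Ok(items);
    //             }
    //             cursor = page.next_cursor;
    //         }
    //     }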

    /// Retrieves a value for each requested LLM metric for a given service,
    /// considering records created after `limit_datetime`.
    async fn get_llm_metric_values(
        pool: &Pool<Postgres>,
        service_info: &ServiceInfo,
        limit_datetime: &DateTime<Utc>,
        metrics: &[String],
    ) -> Result<HashMap<String, f64>, SqlError> {
        let query = Queries::GetLLMMetricValues.get_query();

        let records = sqlx::query(&query.sql)
            .bind(limit_datetime)
            .bind(&service_info.space)
            .bind(&service_info.name)
            .bind(&service_info.version)
            .bind(metrics)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        // Collect (metric, value) rows into a name -> value map
        let metric_map = records
            .into_iter()
            .map(|row| {
                let metric = row.get("metric");
                let value = row.get("value");
                (metric, value)
            })
            .collect();

        Ok(metric_map)
    }

    /// Queries the database for LLM metric records based on a time window
    /// and aggregation.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `minutes` - The size of the time window, in minutes
    ///
    /// # Returns
    /// * `BinnedMetrics` aggregated over the requested window
    #[instrument(skip_all)]
    async fn get_records(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        minutes: i32,
    ) -> Result<BinnedMetrics, SqlError> {
        // Bin width in minutes, e.g. a 24-hour interval (1440 minutes) at
        // max_data_points = 100 yields 14.4-minute bins.
        let bin = params.time_interval.to_minutes() as f64 / params.max_data_points as f64;

        let query = Queries::GetBinnedMetrics.get_query();

        let records: Vec<BinnedMetricWrapper> = sqlx::query_as(&query.sql)
            .bind(bin)
            .bind(minutes)
            .bind(&params.space)
            .bind(&params.name)
            .bind(&params.version)
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(BinnedMetrics::from_vec(
            records.into_iter().map(|wrapper| wrapper.0).collect(),
        ))
    }

    /// Helper for merging binned metric results from multiple queries
    /// (e.g. current and archived data) into a single `BinnedMetrics` map.
    fn merge_feature_results(
        results: BinnedMetrics,
        map: &mut BinnedMetrics,
    ) -> Result<(), SqlError> {
        for (name, metric) in results.metrics {
            let metric_clone = metric.clone();
            map.metrics
                .entry(name)
                .and_modify(|existing| {
                    existing.created_at.extend(metric_clone.created_at);
                    existing.stats.extend(metric_clone.stats);
                })
                .or_insert(metric);
        }

        Ok(())
    }
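
    // The merge above is the standard `entry` pattern: extend an existing
    // series in place, or insert the incoming one. The same shape on a plain
    // HashMap, for illustration:
    //
    //     use std::collections::HashMap;
    //
    //     let mut map: HashMap<String, Vec<f64>> = HashMap::new();
    //     for (name, series) in [("acc".to_string(), vec![0.9]), ("acc".to_string(), vec![0.8])] {
    //         map.entry(name)
    //             .and_modify(|existing| existing.extend(series.clone()))
    //             .or_insert(series);
    //     }
    //     assert_eq!(map["acc"], vec![0.9, 0.8]);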

    /// DataFusion implementation for getting LLM metric records from archived data.
    ///
    /// # Arguments
    /// * `params` - The drift request parameters
    /// * `begin` - The start time of the time window
    /// * `end` - The end time of the time window
    /// * `minutes` - The number of minutes to bin the data
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * `BinnedMetrics` built from the archived data
    #[instrument(skip_all)]
    async fn get_archived_metric_records(
        params: &DriftRequest,
        begin: DateTime<Utc>,
        end: DateTime<Utc>,
        minutes: i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting archived LLM metrics for params: {:?}", params);
        let path = format!(
            "{}/{}/{}/llm_metric",
            params.space, params.name, params.version
        );
        let bin = minutes as f64 / params.max_data_points as f64;
        let archived_df = ParquetDataFrame::new(storage_settings, &RecordType::LLMMetric)?
            .get_binned_metrics(
                &path,
                &bin,
                &begin,
                &end,
                &params.space,
                &params.name,
                &params.version,
            )
            .await
            .inspect_err(|e| {
                error!("Failed to get archived LLM metrics: {:?}", e);
            })?;

        Ok(BinnedMetricsExtractor::dataframe_to_binned_metrics(archived_df).await?)
    }

    /// Queries for binned LLM metric values over a time window and aggregation,
    /// combining current data from Postgres with archived data from object
    /// storage when a custom interval reaches past the retention period.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `params` - The drift request parameters
    /// * `retention_period` - The retention period used to split current vs archived data
    /// * `storage_settings` - The object storage settings
    ///
    /// # Returns
    /// * `BinnedMetrics` merged from current and archived data
    #[instrument(skip_all)]
    async fn get_binned_llm_metric_values(
        pool: &Pool<Postgres>,
        params: &DriftRequest,
        retention_period: &i32,
        storage_settings: &ObjectStorageSettings,
    ) -> Result<BinnedMetrics, SqlError> {
        debug!("Getting binned LLM metric records for {:?}", params);

        if !params.has_custom_interval() {
            debug!("No custom interval provided, using default");
            let minutes = params.time_interval.to_minutes();
            return Self::get_records(pool, params, minutes).await;
        }

        debug!("Custom interval provided, using custom interval");
        let interval = params.clone().to_custom_interval().unwrap();
        let timestamps = split_custom_interval(interval.start, interval.end, retention_period)?;
        let mut custom_metric_map = BinnedMetrics::default();

        // get data from postgres
        if let Some(minutes) = timestamps.current_minutes {
            let current_results = Self::get_records(pool, params, minutes).await?;
            Self::merge_feature_results(current_results, &mut custom_metric_map)?;
        }

        // get archived data
        if let Some((archive_begin, archive_end)) = timestamps.archived_range {
            if let Some(archived_minutes) = timestamps.archived_minutes {
                let archived_results = Self::get_archived_metric_records(
                    params,
                    archive_begin,
                    archive_end,
                    archived_minutes,
                    storage_settings,
                )
                .await?;
                Self::merge_feature_results(archived_results, &mut custom_metric_map)?;
            }
        }

        Ok(custom_metric_map)
    }
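
    // Shape of the time split used above: data younger than the retention
    // period is served from Postgres, older data from the Parquet archive.
    // Field names are taken from the usage above; the concrete return type of
    // `split_custom_interval` is defined in `crate::sql::utils`.
    //
    //     let ts = split_custom_interval(interval.start, interval.end, retention_period)?;
    //     if let Some(minutes) = ts.current_minutes {
    //         // window (or the recent part of it) still inside Postgres
    //     }
    //     if let (Some((begin, end)), Some(minutes)) =
    //         (ts.archived_range, ts.archived_minutes)
    //     {
    //         // window (or the older part of it) already archived to object storage
    //     }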

    /// Retrieves the next pending LLM drift task from drift_records.
    async fn get_pending_llm_drift_record(
        pool: &Pool<Postgres>,
    ) -> Result<Option<LLMRecord>, SqlError> {
        let query = Queries::GetPendingLLMDriftTask.get_query();
        let result: Option<LLMRecordWrapper> = sqlx::query_as(&query.sql)
            .fetch_optional(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        Ok(result.map(|wrapper| wrapper.0))
    }

    /// Updates the status, score, and workflow duration of an LLM drift record
    /// after it has been processed.
    #[instrument(skip_all)]
    async fn update_llm_drift_record_status(
        pool: &Pool<Postgres>,
        record: &LLMRecord,
        status: Status,
        workflow_duration: Option<i32>, // Duration in seconds
    ) -> Result<(), SqlError> {
        let query = Queries::UpdateLLMDriftTask.get_query();

        let _query_result = sqlx::query(&query.sql)
            .bind(status.as_str())
            .bind(record.score.clone())
            .bind(workflow_duration)
            .bind(&record.uid)
            .execute(pool)
            .await
            .inspect_err(|e| {
                error!("Failed to update LLM drift record status: {:?}", e);
            })?;

        Ok(())
    }
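
    // Worker-loop sketch tying the two methods above together: poll for a
    // pending record, run the evaluation workflow, and write the terminal
    // status back. `PgClient` is the hypothetical implementor from earlier,
    // and `Status::Processed` stands in for whatever terminal variant the
    // crate actually defines.
    //
    //     async fn drain_pending(pool: &Pool<Postgres>) -> Result<(), SqlError> {
    //         while let Some(record) = PgClient::get_pending_llm_drift_record(pool).await? {
    //             let started = std::time::Instant::now();
    //             // ... run the evaluation workflow for `record` here ...
    //             let secs = started.elapsed().as_secs() as i32;
    //             PgClient::update_llm_drift_record_status(
    //                 pool,
    //                 &record,
    //                 Status::Processed,
    //                 Some(secs),
    //             )
    //             .await?;
    //         }
    //         Ok(())
    //     }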
}