scouter_drift/llm/
poller.rs

// Poller for LLM drift records: pulls records that are "pending" and processes them.
2use crate::error::DriftError;
3use crate::llm::evaluator::LLMEvaluator;
4use potato_head::Score;
5use scouter_sql::sql::traits::{LLMDriftSqlLogic, ProfileSqlLogic};
6use scouter_sql::PostgresClient;
7use scouter_types::llm::LLMDriftProfile;
8use scouter_types::{DriftType, GetProfileRequest, LLMRecord, Status};
9use sqlx::{Pool, Postgres};
10use std::collections::HashMap;
11use std::time::Duration;
12use tokio::time::sleep;
13use tracing::{debug, error, info, instrument};
14
15pub struct LLMPoller {
16    db_pool: Pool<Postgres>,
17    max_retries: usize,
18}
19
20impl LLMPoller {
21    pub fn new(db_pool: &Pool<Postgres>, max_retries: usize) -> Self {
22        LLMPoller {
23            db_pool: db_pool.clone(),
24            max_retries,
25        }
26    }
27
28    #[instrument(skip_all)]
29    pub async fn process_drift_record(
30        &mut self,
31        record: &LLMRecord,
32        profile: &LLMDriftProfile,
33    ) -> Result<(HashMap<String, Score>, Option<i32>), DriftError> {
34        debug!("Processing workflow");
35
36        match LLMEvaluator::process_drift_record(record, profile).await {
37            Ok((metrics, score_map, workflow_duration)) => {
38                PostgresClient::insert_llm_metric_values_batch(&self.db_pool, &metrics)
39                    .await
40                    .inspect_err(|e| {
41                        error!("Failed to insert LLM metric values: {:?}", e);
42                    })?;
43
44                return Ok((score_map, workflow_duration));
45            }
46            Err(e) => {
47                error!("Failed to process drift record: {:?}", e);
48                return Err(DriftError::LLMEvaluatorError(e.to_string()));
49            }
50        };
51    }
52
53    #[instrument(skip_all)]
54    pub async fn do_poll(&mut self) -> Result<bool, DriftError> {
55        // Get task from the database (query uses skip lock to pull task and update to processing)
56        let task = PostgresClient::get_pending_llm_drift_record(&self.db_pool).await?;
57
58        let Some(mut task) = task else {
59            return Ok(false);
60        };
61
62        info!(
63            "Processing llm drift record for profile: {}/{}/{}",
64            task.space, task.name, task.version
65        );
66
67        // get get/load profile and reset agents
68        let request = GetProfileRequest {
69            space: task.space.clone(),
70            name: task.name.clone(),
71            version: task.version.clone(),
72            drift_type: DriftType::LLM,
73        };
74
75        let mut llm_profile = if let Some(profile) =
76            PostgresClient::get_drift_profile(&self.db_pool, &request).await?
77        {
78            let llm_profile: LLMDriftProfile =
79                serde_json::from_value(profile).inspect_err(|e| {
80                    error!("Failed to deserialize LLM drift profile: {:?}", e);
81                })?;
82            llm_profile
83        } else {
84            error!(
85                "No LLM drift profile found for {}/{}/{}",
86                task.space, task.name, task.version
87            );
88            return Ok(false);
89        };
90        let mut retry_count = 0;
91
92        llm_profile.workflow.reset_agents().await.inspect_err(|e| {
93            error!("Failed to reset agents: {:?}", e);
94        })?;
95
96        loop {
97            match self.process_drift_record(&task, &llm_profile).await {
98                Ok((result, workflow_duration)) => {
99                    task.score = serde_json::to_value(result).inspect_err(|e| {
100                        error!("Failed to serialize score map: {:?}", e);
101                    })?;
102
103                    PostgresClient::update_llm_drift_record_status(
104                        &self.db_pool,
105                        &task,
106                        Status::Processed,
107                        workflow_duration,
108                    )
109                    .await?;
110                    break;
111                }
112                Err(e) => {
113                    error!(
114                        "Failed to process drift record (attempt {}): {:?}",
115                        retry_count + 1,
116                        e
117                    );
118
119                    retry_count += 1;
120                    if retry_count >= self.max_retries {
121                        // Update the record status to error
122                        PostgresClient::update_llm_drift_record_status(
123                            &self.db_pool,
124                            &task,
125                            Status::Failed,
126                            None,
127                        )
128                        .await?;
129                        return Err(DriftError::LLMEvaluatorError(e.to_string()));
130                    } else {
131                        // Exponential backoff before retrying
132                        sleep(Duration::from_millis(100 * 2_u64.pow(retry_count as u32))).await;
133                    }
134                }
135            }
136        }
137
138        Ok(true)
139    }
140
141    #[instrument(skip_all)]
142    pub async fn poll_for_tasks(&mut self) -> Result<(), DriftError> {
143        let result = self.do_poll().await;
144
145        // silent error handling
146        match result {
147            Ok(true) => {
148                debug!("Successfully processed drift record");
149                Ok(())
150            }
151            Ok(false) => Ok(()),
152            Err(e) => {
153                error!("Error processing drift record: {:?}", e);
154                Ok(())
155            }
156        }
157    }
158}