// scouter_events/queue/llm/record_queue.rs

1use crate::error::FeatureQueueError;
2use crate::queue::traits::FeatureQueue;
3use core::result::Result::Ok;
4use scouter_types::BoxedLLMDriftServerRecord;
5use scouter_types::LLMRecord;
6use scouter_types::QueueExt;
7use scouter_types::{
8    llm::LLMDriftProfile, LLMDriftServerRecord, MessageRecord, ServerRecord, ServerRecords,
9};
10use tracing::instrument;
11pub struct LLMRecordQueue {
12    drift_profile: LLMDriftProfile,
13    empty_queue: Vec<LLMRecord>,
14}
15
16impl LLMRecordQueue {
17    pub fn new(drift_profile: LLMDriftProfile) -> Self {
18        LLMRecordQueue {
19            drift_profile,
20            empty_queue: Vec::new(),
21        }
22    }
23
24    /// Insert llm records into the queue
25    ///
26    /// # Arguments
27    ///
28    /// * `records` - A vector of llm records to insert into the queue
29    ///
30    /// # Returns
31    ///
32    /// * `Result<(), FeatureQueueError>` - A result indicating success or failure
33    #[instrument(skip_all, name = "insert_llm")]
34    pub fn insert(
35        &self,
36        records: Vec<&LLMRecord>,
37        queue: &mut Vec<LLMRecord>,
38    ) -> Result<(), FeatureQueueError> {
39        for record in records {
40            queue.push(record.clone());
41        }
42        Ok(())
43    }
44
45    fn create_drift_records(
46        &self,
47        queue: Vec<LLMRecord>,
48    ) -> Result<ServerRecords, FeatureQueueError> {
49        let records = queue
50            .iter()
51            .map(|record| {
52                ServerRecord::LLMDrift(BoxedLLMDriftServerRecord::new(
53                    LLMDriftServerRecord::new_rs(
54                        self.drift_profile.config.space.clone(),
55                        self.drift_profile.config.name.clone(),
56                        self.drift_profile.config.version.clone(),
57                        record.prompt.clone(),
58                        record.context.clone(),
59                        record.created_at,
60                        record.uid.clone(),
61                        record.score.clone(),
62                    ),
63                )) // Removed the semicolon here
64            })
65            .collect::<Vec<ServerRecord>>();
66
67        Ok(ServerRecords::new(records))
68    }
69}
70
71impl FeatureQueue for LLMRecordQueue {
72    fn create_drift_records_from_batch<T: QueueExt>(
73        &self,
74        batch: Vec<T>,
75    ) -> Result<MessageRecord, FeatureQueueError> {
76        // clones the empty map (so we don't need to recreate it on each call)
77        let mut queue = self.empty_queue.clone();
78
79        for elem in batch {
80            self.insert(elem.llm_records(), &mut queue)?;
81        }
82
83        Ok(MessageRecord::ServerRecords(
84            self.create_drift_records(queue)?,
85        ))
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use potato_head::create_score_prompt;
    use scouter_types::llm::{LLMAlertConfig, LLMDriftConfig, LLMDriftMetric, LLMDriftProfile};
    use scouter_types::AlertThreshold;

    /// Builds a drift profile with two metrics ("coherence" and "relevancy")
    /// that share the same scoring prompt.
    async fn get_test_drift_profile() -> LLMDriftProfile {
        let score_prompt = create_score_prompt(Some(vec!["input".to_string()]));

        let coherence = LLMDriftMetric::new(
            "coherence",
            5.0,
            AlertThreshold::Below,
            Some(0.5),
            Some(score_prompt.clone()),
        )
        .unwrap();

        let relevancy = LLMDriftMetric::new(
            "relevancy",
            5.0,
            AlertThreshold::Below,
            None,
            Some(score_prompt.clone()),
        )
        .unwrap();

        let drift_config = LLMDriftConfig::new(
            "scouter",
            "ML",
            "0.1.0",
            25,
            LLMAlertConfig::default(),
            None,
        )
        .unwrap();

        LLMDriftProfile::from_metrics(drift_config, vec![coherence, relevancy])
            .await
            .unwrap()
    }

    /// Builds a record whose context is `{"input": "test"}` and that has no prompt.
    fn make_record() -> LLMRecord {
        let mut context = serde_json::Map::new();
        context.insert("input".into(), serde_json::Value::String("test".into()));
        LLMRecord::new_rs(Some(serde_json::Value::Object(context)), None)
    }

    #[test]
    fn test_feature_queue_llm_insert_record() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let profile = runtime.block_on(get_test_drift_profile());
        let feature_queue = LLMRecordQueue::new(profile);

        // The template buffer must start out empty.
        assert_eq!(feature_queue.empty_queue.len(), 0);

        let record_batch: Vec<LLMRecord> = (0..1).map(|_| make_record()).collect();

        let records = feature_queue
            .create_drift_records_from_batch(record_batch)
            .unwrap();

        // A single input record yields a single server record.
        assert_eq!(records.len(), 1);
    }
}