scouter_events/queue/llm/
record_queue.rs1use crate::error::FeatureQueueError;
2use crate::queue::traits::FeatureQueue;
3use core::result::Result::Ok;
4use scouter_types::BoxedLLMDriftServerRecord;
5use scouter_types::LLMRecord;
6use scouter_types::QueueExt;
7use scouter_types::{
8 llm::LLMDriftProfile, LLMDriftServerRecord, MessageRecord, ServerRecord, ServerRecords,
9};
10use tracing::instrument;
11pub struct LLMRecordQueue {
12 drift_profile: LLMDriftProfile,
13 empty_queue: Vec<LLMRecord>,
14}
15
16impl LLMRecordQueue {
17 pub fn new(drift_profile: LLMDriftProfile) -> Self {
18 LLMRecordQueue {
19 drift_profile,
20 empty_queue: Vec::new(),
21 }
22 }
23
24 #[instrument(skip_all, name = "insert_llm")]
34 pub fn insert(
35 &self,
36 records: Vec<&LLMRecord>,
37 queue: &mut Vec<LLMRecord>,
38 ) -> Result<(), FeatureQueueError> {
39 for record in records {
40 queue.push(record.clone());
41 }
42 Ok(())
43 }
44
45 fn create_drift_records(
46 &self,
47 queue: Vec<LLMRecord>,
48 ) -> Result<ServerRecords, FeatureQueueError> {
49 let records = queue
50 .iter()
51 .map(|record| {
52 ServerRecord::LLMDrift(BoxedLLMDriftServerRecord::new(
53 LLMDriftServerRecord::new_rs(
54 self.drift_profile.config.space.clone(),
55 self.drift_profile.config.name.clone(),
56 self.drift_profile.config.version.clone(),
57 record.prompt.clone(),
58 record.context.clone(),
59 record.created_at,
60 record.uid.clone(),
61 record.score.clone(),
62 ),
63 )) })
65 .collect::<Vec<ServerRecord>>();
66
67 Ok(ServerRecords::new(records))
68 }
69}
70
71impl FeatureQueue for LLMRecordQueue {
72 fn create_drift_records_from_batch<T: QueueExt>(
73 &self,
74 batch: Vec<T>,
75 ) -> Result<MessageRecord, FeatureQueueError> {
76 let mut queue = self.empty_queue.clone();
78
79 for elem in batch {
80 self.insert(elem.llm_records(), &mut queue)?;
81 }
82
83 Ok(MessageRecord::ServerRecords(
84 self.create_drift_records(queue)?,
85 ))
86 }
87}
88
#[cfg(test)]
mod tests {

    use super::*;
    use potato_head::create_score_prompt;
    use scouter_types::llm::{LLMAlertConfig, LLMDriftConfig, LLMDriftMetric, LLMDriftProfile};
    use scouter_types::AlertThreshold;

    /// Builds a drift profile with two metrics ("coherence" and "relevancy")
    /// sharing a score prompt that declares an `input` variable.
    async fn get_test_drift_profile() -> LLMDriftProfile {
        let prompt = create_score_prompt(Some(vec!["input".to_string()]));
        let metric1 = LLMDriftMetric::new(
            "coherence",
            5.0,
            AlertThreshold::Below,
            Some(0.5),
            Some(prompt.clone()),
        )
        .unwrap();

        let metric2 = LLMDriftMetric::new(
            "relevancy",
            5.0,
            AlertThreshold::Below,
            None,
            Some(prompt.clone()),
        )
        .unwrap();

        let alert_config = LLMAlertConfig::default();
        let drift_config =
            LLMDriftConfig::new("scouter", "ML", "0.1.0", 25, alert_config, None).unwrap();

        LLMDriftProfile::from_metrics(drift_config, vec![metric1, metric2])
            .await
            .unwrap()
    }

    /// Builds a queue backed by the test profile, driving the async profile
    /// constructor to completion on a dedicated Tokio runtime.
    fn build_queue() -> LLMRecordQueue {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        LLMRecordQueue::new(runtime.block_on(get_test_drift_profile()))
    }

    /// Creates a record whose context is the JSON object `{"input": input}`,
    /// matching the variable declared by the test prompt.
    fn make_record(input: &str) -> LLMRecord {
        let context = serde_json::json!({ "input": input });
        LLMRecord::new_rs(Some(context), None)
    }

    #[test]
    fn test_feature_queue_llm_insert_record() {
        let feature_queue = build_queue();

        assert_eq!(feature_queue.empty_queue.len(), 0);

        let records = feature_queue
            .create_drift_records_from_batch(vec![make_record("test")])
            .unwrap();

        assert_eq!(records.len(), 1);
    }

    #[test]
    fn test_feature_queue_llm_multiple_records() {
        let feature_queue = build_queue();

        let batch: Vec<LLMRecord> = ["first", "second", "third"]
            .iter()
            .copied()
            .map(make_record)
            .collect();

        let records = feature_queue
            .create_drift_records_from_batch(batch)
            .unwrap();

        // One server record per input record.
        assert_eq!(records.len(), 3);
    }

    #[test]
    fn test_feature_queue_llm_empty_batch() {
        let feature_queue = build_queue();

        let records = feature_queue
            .create_drift_records_from_batch(Vec::<LLMRecord>::new())
            .unwrap();

        assert_eq!(records.len(), 0);
    }
}