swiftide_query/evaluators/
ragas.rs1use anyhow::Result;
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use std::{collections::HashMap, str::FromStr, sync::Arc};
41use tokio::sync::RwLock;
42
43use swiftide_core::{
44 querying::{states, Query, QueryEvaluation},
45 EvaluateQuery,
46};
47
48#[derive(Debug, Clone)]
50pub struct Ragas {
51 dataset: Arc<RwLock<EvaluationDataSet>>,
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct EvaluationData {
57 question: String,
58 answer: String,
59 contexts: Vec<String>,
60 ground_truth: String,
61}
62
63#[derive(Debug, Clone)]
65pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
66
67impl Ragas {
68 pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
71 Ragas {
72 dataset: Arc::new(RwLock::new(questions.into())),
73 }
74 }
75
76 pub async fn questions(&self) -> Vec<Query<states::Pending>> {
77 self.dataset.read().await.0.keys().map(Into::into).collect()
78 }
79
80 pub async fn record_answers_as_ground_truth(&self) {
82 self.dataset.write().await.record_answers_as_ground_truth();
83 }
84
85 pub async fn to_json(&self) -> String {
87 self.dataset.read().await.to_json()
88 }
89}
90
91#[async_trait]
92impl EvaluateQuery for Ragas {
93 #[tracing::instrument(skip_all)]
94 async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
95 let mut dataset = self.dataset.write().await;
96 dataset.upsert_evaluation(&query)
97 }
98}
99
100impl EvaluationDataSet {
101 pub(crate) fn record_answers_as_ground_truth(&mut self) {
102 for data in self.0.values_mut() {
103 data.ground_truth.clone_from(&data.answer);
104 }
105 }
106
107 pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
108 match query {
109 QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
110 QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
111 }
112 }
113
114 fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
116 let question = query.original();
117 let data = self
118 .0
119 .get_mut(question)
120 .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
121
122 data.contexts = query
123 .documents()
124 .iter()
125 .map(|d| d.content().to_string())
126 .collect::<Vec<_>>();
127 Ok(())
128 }
129
130 fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
131 let question = query.original();
132 let data = self
133 .0
134 .get_mut(question)
135 .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
136
137 data.answer = query.answer().to_string();
138
139 Ok(())
140 }
141
142 pub(crate) fn to_json(&self) -> String {
163 json!(self.0.values().collect::<Vec<_>>()).to_string()
164 }
165}
166
167impl From<Vec<String>> for EvaluationDataSet {
169 fn from(val: Vec<String>) -> Self {
170 EvaluationDataSet(
171 val.into_iter()
172 .map(|question| {
173 (
174 question.clone(),
175 EvaluationData {
176 question,
177 ..EvaluationData::default()
178 },
179 )
180 })
181 .collect(),
182 )
183 }
184}
185
186impl From<&[String]> for EvaluationDataSet {
187 fn from(val: &[String]) -> Self {
188 EvaluationDataSet(
189 val.iter()
190 .map(|question| {
191 (
192 question.to_string(),
193 EvaluationData {
194 question: question.to_string(),
195 ..EvaluationData::default()
196 },
197 )
198 })
199 .collect(),
200 )
201 }
202}
203
204impl From<Vec<(String, String)>> for EvaluationDataSet {
206 fn from(val: Vec<(String, String)>) -> Self {
207 EvaluationDataSet(
208 val.into_iter()
209 .map(|(question, ground_truth)| {
210 (
211 question.clone(),
212 EvaluationData {
213 question,
214 ground_truth,
215 ..EvaluationData::default()
216 },
217 )
218 })
219 .collect(),
220 )
221 }
222}
223
224impl FromStr for EvaluationDataSet {
226 type Err = serde_json::Error;
227
228 fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
229 let data: Vec<EvaluationData> = serde_json::from_str(val)?;
230 Ok(EvaluationDataSet(
231 data.into_iter()
232 .map(|data| (data.question.clone(), data))
233 .collect(),
234 ))
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{Query, QueryEvaluation};
    use tokio::sync::RwLock;

    // Tests that exercise the async `Ragas` API run on tokio; tests of the synchronous
    // `EvaluationDataSet` API use the plain test harness and need no runtime.

    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        for question in questions {
            assert!(stored_questions.iter().any(|q| q.original() == question));
        }
    }

    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let dataset = Arc::new(RwLock::new(EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )])));
        let ragas = Ragas {
            dataset: dataset.clone(),
        };

        {
            let mut lock = dataset.write().await;
            let data = lock.0.get_mut("What is Rust?").unwrap();
            data.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let json_output = ragas.to_json().await;
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::RetrieveDocuments(query))
            .await
            .unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::AnswerQuery(query))
            .await
            .unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    #[test]
    fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let data = dataset.0.get_mut("What is Rust?").unwrap();
        data.answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A programming language");
    }

    #[test]
    fn test_evaluation_dataset_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);

        let json_output = dataset.to_json();
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[test]
    fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[test]
    fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::AnswerQuery(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    #[test]
    fn test_evaluation_dataset_upsert_unknown_question_errors() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Tokio?")
            .current("An async runtime")
            .build()
            .unwrap();
        let result = dataset.upsert_evaluation(&QueryEvaluation::AnswerQuery(query));

        // Unknown questions are rejected rather than silently inserted.
        assert!(result.is_err());
        assert!(!dataset.0.contains_key("What is Tokio?"));
        assert!(dataset.0.get("What is Rust?").unwrap().answer.is_empty());
    }

    #[test]
    fn test_evaluation_dataset_from_slice() {
        let questions = ["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let dataset = EvaluationDataSet::from(&questions[..]);

        assert_eq!(dataset.0.len(), 2);
        let data = dataset.0.get("What is Tokio?").unwrap();
        assert_eq!(data.question, "What is Tokio?");
        assert!(data.answer.is_empty());
        assert!(data.contexts.is_empty());
        assert!(data.ground_truth.is_empty());
    }

    #[test]
    fn test_evaluation_dataset_from_str_round_trip() {
        let input = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        let dataset: EvaluationDataSet = input.parse().unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A programming language");
        assert_eq!(dataset.to_json(), input);
    }
}