// swiftide_query/evaluators/ragas.rs
use anyhow::Result;
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use std::{collections::HashMap, str::FromStr, sync::Arc};
41use tokio::sync::RwLock;
42
43use swiftide_core::{
44 querying::{states, Query, QueryEvaluation},
45 EvaluateQuery,
46};
47
48#[derive(Debug, Clone)]
50pub struct Ragas {
51 dataset: Arc<RwLock<EvaluationDataSet>>,
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct EvaluationData {
57 question: String,
58 answer: String,
59 contexts: Vec<String>,
60 ground_truth: String,
61}
62
63#[derive(Debug, Clone)]
65pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
66
67impl Ragas {
68 pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
71 Ragas {
72 dataset: Arc::new(RwLock::new(questions.into())),
73 }
74 }
75
76 pub async fn questions(&self) -> Vec<Query<states::Pending>> {
77 self.dataset.read().await.0.keys().map(Into::into).collect()
78 }
79
80 pub async fn record_answers_as_ground_truth(&self) {
82 self.dataset.write().await.record_answers_as_ground_truth();
83 }
84
85 pub async fn to_json(&self) -> String {
87 self.dataset.read().await.to_json()
88 }
89}
90
91#[async_trait]
92impl EvaluateQuery for Ragas {
93 #[tracing::instrument(skip_all)]
94 async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
95 let mut dataset = self.dataset.write().await;
96 dataset.upsert_evaluation(&query)
97 }
98}
99
100impl EvaluationDataSet {
101 pub(crate) fn record_answers_as_ground_truth(&mut self) {
102 for data in self.0.values_mut() {
103 data.ground_truth.clone_from(&data.answer);
104 }
105 }
106
107 pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
108 match query {
109 QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
110 QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
111 }
112 }
113
114 fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
116 let question = query.original();
117 let data = self
118 .0
119 .get_mut(question)
120 .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
121
122 data.contexts = query
123 .documents()
124 .iter()
125 .map(|d| d.content().to_string())
126 .collect::<Vec<_>>();
127 Ok(())
128 }
129
130 fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
131 let question = query.original();
132 let data = self
133 .0
134 .get_mut(question)
135 .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
136
137 data.answer = query.answer().to_string();
138
139 Ok(())
140 }
141
142 pub(crate) fn to_json(&self) -> String {
163 let json_value = json!(self.0.values().collect::<Vec<_>>());
164 serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json_value.to_string())
165 }
166}
167
168impl From<Vec<String>> for EvaluationDataSet {
170 fn from(val: Vec<String>) -> Self {
171 EvaluationDataSet(
172 val.into_iter()
173 .map(|question| {
174 (
175 question.clone(),
176 EvaluationData {
177 question,
178 ..EvaluationData::default()
179 },
180 )
181 })
182 .collect(),
183 )
184 }
185}
186
187impl From<&[String]> for EvaluationDataSet {
188 fn from(val: &[String]) -> Self {
189 EvaluationDataSet(
190 val.iter()
191 .map(|question| {
192 (
193 question.to_string(),
194 EvaluationData {
195 question: question.to_string(),
196 ..EvaluationData::default()
197 },
198 )
199 })
200 .collect(),
201 )
202 }
203}
204
205impl From<Vec<(String, String)>> for EvaluationDataSet {
207 fn from(val: Vec<(String, String)>) -> Self {
208 EvaluationDataSet(
209 val.into_iter()
210 .map(|(question, ground_truth)| {
211 (
212 question.clone(),
213 EvaluationData {
214 question,
215 ground_truth,
216 ..EvaluationData::default()
217 },
218 )
219 })
220 .collect(),
221 )
222 }
223}
224
225impl FromStr for EvaluationDataSet {
227 type Err = serde_json::Error;
228
229 fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
230 let data: Vec<EvaluationData> = serde_json::from_str(val)?;
231 Ok(EvaluationDataSet(
232 data.into_iter()
233 .map(|data| (data.question.clone(), data))
234 .collect(),
235 ))
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{Query, QueryEvaluation};
    use tokio::sync::RwLock;

    /// The question used by the single-record tests.
    const QUESTION: &str = "What is Rust?";

    /// Expected `to_json` output for a single record that only has a ground truth.
    ///
    /// `serde_json::to_string_pretty` indents with two spaces per nesting
    /// level, so the record's fields sit four spaces in.
    const EXPECTED_JSON: &str = "[\n  {\n    \"answer\": \"\",\n    \"contexts\": [],\n    \"ground_truth\": \"A programming language\",\n    \"question\": \"What is Rust?\"\n  }\n]";

    /// Wraps `dataset` in an evaluator.
    fn ragas_with(dataset: EvaluationDataSet) -> Ragas {
        Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        }
    }

    /// A dataset holding only `QUESTION`, with every other field empty.
    fn question_only_dataset() -> EvaluationDataSet {
        EvaluationDataSet::from(vec![QUESTION.to_string()])
    }

    /// A dataset holding `QUESTION` together with a ground truth.
    fn ground_truth_dataset() -> EvaluationDataSet {
        EvaluationDataSet::from(vec![(
            QUESTION.to_string(),
            "A programming language".to_string(),
        )])
    }

    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        for question in questions {
            assert!(stored_questions.iter().any(|q| q.original() == question));
        }
    }

    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let ragas = ragas_with(ground_truth_dataset());

        {
            // Scoped so the write guard is released before the evaluator
            // takes the lock again below.
            let mut guard = ragas.dataset.write().await;
            let data = guard.0.get_mut(QUESTION).unwrap();
            data.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get(QUESTION).unwrap();
        assert_eq!(data.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let ragas = ragas_with(ground_truth_dataset());

        let json_output = ragas.to_json().await;
        assert_eq!(json_output, EXPECTED_JSON);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let ragas = ragas_with(question_only_dataset());

        let query = Query::builder()
            .original(QUESTION)
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::RetrieveDocuments(query))
            .await
            .unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get(QUESTION).unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let ragas = ragas_with(question_only_dataset());

        let query = Query::builder()
            .original(QUESTION)
            .current("A systems programming language")
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::AnswerQuery(query))
            .await
            .unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get(QUESTION).unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    #[tokio::test]
    async fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = question_only_dataset();
        let data = dataset.0.get_mut(QUESTION).unwrap();
        data.answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let data = dataset.0.get(QUESTION).unwrap();
        assert_eq!(data.ground_truth, "A programming language");
    }

    #[tokio::test]
    async fn test_evaluation_dataset_to_json() {
        let dataset = ground_truth_dataset();

        let json_output = dataset.to_json();
        assert_eq!(json_output, EXPECTED_JSON);
    }

    #[tokio::test]
    async fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = question_only_dataset();

        let query = Query::builder()
            .original(QUESTION)
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query))
            .unwrap();

        let data = dataset.0.get(QUESTION).unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = question_only_dataset();

        let query = Query::builder()
            .original(QUESTION)
            .current("A systems programming language")
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::AnswerQuery(query))
            .unwrap();

        let data = dataset.0.get(QUESTION).unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }
}