swiftide_query/evaluators/
ragas.rs

1//! The Ragas evaluator allows you to export a RAGAS compatible JSON dataset.
2//!
3//! RAGAS requires a ground truth to compare to. You can either record the answers for an initial
4//! dataset, or provide the ground truth yourself.
5//!
6//! Refer to the ragas documentation on how to use the dataset or take a look at a more involved
7//! example at [swiftide-tutorials](https://github.com/bosun-ai/swiftide-tutorial).
8//!
9//! # Example
10//!
11//! ```ignore
12//! # use swiftide_query::*;
13//! # use anyhow::{Result, Context};
14//! # #[tokio::main]
15//! # async fn main() -> anyhow::Result<()> {
16//!
17//! let openai = swiftide::integrations::openai::OpenAi::default();
18//! let qdrant = swiftide::integrations::qdrant::Qdrant::default();
19//!
20//! let ragas = evaluators::ragas::Ragas::from_prepared_questions(questions);
21//!
22//! let pipeline = query::Pipeline::default()
23//! .evaluate_with(ragas.clone())
24//! .then_transform_query(query_transformers::GenerateSubquestions::from_client(openai.clone()))
25//! .then_transform_query(query_transformers::Embed::from_client(
26//! openai.clone(),
27//! ))
28//! .then_retrieve(qdrant.clone())
29//! .then_answer(answers::Simple::from_client(openai.clone()));
30//!
31//! pipeline.query_all(ragas.questions().await).await.unwrap();
32//!
33//! std::fs::write("output.json", ragas.to_json().await).unwrap();
34//! # Ok(())
//! # }
//! ```
36use anyhow::Result;
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use std::{collections::HashMap, str::FromStr, sync::Arc};
41use tokio::sync::RwLock;
42
43use swiftide_core::{
44    querying::{states, Query, QueryEvaluation},
45    EvaluateQuery,
46};
47
48/// Ragas evaluator to be used in a pipeline
49#[derive(Debug, Clone)]
50pub struct Ragas {
51    dataset: Arc<RwLock<EvaluationDataSet>>,
52}
53
54/// Row structure for RAGAS compatible JSON
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct EvaluationData {
57    question: String,
58    answer: String,
59    contexts: Vec<String>,
60    ground_truth: String,
61}
62
63/// Dataset for RAGAS compatible JSON, indexed by question
64#[derive(Debug, Clone)]
65pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
66
67impl Ragas {
68    /// Builds a new Ragas evaluator from a list of questions or a list of tuples with questions and
69    /// ground truths. You can also call `parse` to load a dataset from a JSON string.
70    pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
71        Ragas {
72            dataset: Arc::new(RwLock::new(questions.into())),
73        }
74    }
75
76    pub async fn questions(&self) -> Vec<Query<states::Pending>> {
77        self.dataset.read().await.0.keys().map(Into::into).collect()
78    }
79
80    /// Records the current answers as ground truths in the dataset
81    pub async fn record_answers_as_ground_truth(&self) {
82        self.dataset.write().await.record_answers_as_ground_truth();
83    }
84
85    /// Outputs the dataset as a JSON string compatible with RAGAS
86    pub async fn to_json(&self) -> String {
87        self.dataset.read().await.to_json()
88    }
89}
90
91#[async_trait]
92impl EvaluateQuery for Ragas {
93    #[tracing::instrument(skip_all)]
94    async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
95        let mut dataset = self.dataset.write().await;
96        dataset.upsert_evaluation(&query)
97    }
98}
99
100impl EvaluationDataSet {
101    pub(crate) fn record_answers_as_ground_truth(&mut self) {
102        for data in self.0.values_mut() {
103            data.ground_truth.clone_from(&data.answer);
104        }
105    }
106
107    pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
108        match query {
109            QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
110            QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
111        }
112    }
113
114    // For each upsort, check if it exists and update it, or return an error
115    fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
116        let question = query.original();
117        let data = self
118            .0
119            .get_mut(question)
120            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
121
122        data.contexts = query
123            .documents()
124            .iter()
125            .map(|d| d.content().to_string())
126            .collect::<Vec<_>>();
127        Ok(())
128    }
129
130    fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
131        let question = query.original();
132        let data = self
133            .0
134            .get_mut(question)
135            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
136
137        data.answer = query.answer().to_string();
138
139        Ok(())
140    }
141
142    /// Outputs json for ragas
143    ///
144    /// # Format
145    ///
146    /// ```json
147    /// [
148    ///   {
149    ///   "question": "What is the capital of France?",
150    ///   "answer": "Paris",
151    ///   "contexts": ["Paris is the capital of France"],
152    ///   "ground_truth": "Paris"
153    ///   },
154    ///   {
155    ///   "question": "What is the capital of France?",
156    ///   "answer": "Paris",
157    ///   "contexts": ["Paris is the capital of France"],
158    ///   "ground_truth": "Paris"
159    ///   }
160    /// ]
161    /// ```
162    pub(crate) fn to_json(&self) -> String {
163        json!(self.0.values().collect::<Vec<_>>()).to_string()
164    }
165}
166
167// Can just do a list of questions leaving ground truth, answers, contexts empty
168impl From<Vec<String>> for EvaluationDataSet {
169    fn from(val: Vec<String>) -> Self {
170        EvaluationDataSet(
171            val.into_iter()
172                .map(|question| {
173                    (
174                        question.clone(),
175                        EvaluationData {
176                            question,
177                            ..EvaluationData::default()
178                        },
179                    )
180                })
181                .collect(),
182        )
183    }
184}
185
186impl From<&[String]> for EvaluationDataSet {
187    fn from(val: &[String]) -> Self {
188        EvaluationDataSet(
189            val.iter()
190                .map(|question| {
191                    (
192                        question.to_string(),
193                        EvaluationData {
194                            question: question.to_string(),
195                            ..EvaluationData::default()
196                        },
197                    )
198                })
199                .collect(),
200        )
201    }
202}
203
204// Can take a list of tuples for questions and ground truths
205impl From<Vec<(String, String)>> for EvaluationDataSet {
206    fn from(val: Vec<(String, String)>) -> Self {
207        EvaluationDataSet(
208            val.into_iter()
209                .map(|(question, ground_truth)| {
210                    (
211                        question.clone(),
212                        EvaluationData {
213                            question,
214                            ground_truth,
215                            ..EvaluationData::default()
216                        },
217                    )
218                })
219                .collect(),
220        )
221    }
222}
223
224/// Parse an existing dataset from a JSON string
225impl FromStr for EvaluationDataSet {
226    type Err = serde_json::Error;
227
228    fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
229        let data: Vec<EvaluationData> = serde_json::from_str(val)?;
230        Ok(EvaluationDataSet(
231            data.into_iter()
232                .map(|data| (data.question.clone(), data))
233                .collect(),
234        ))
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{Query, QueryEvaluation};
    use tokio::sync::RwLock;

    // The evaluator exposes every prepared question as a pending query.
    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        for question in questions {
            assert!(stored_questions.iter().any(|q| q.original() == question));
        }
    }

    // Recording answers through the evaluator overwrites an existing ground truth.
    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let dataset = Arc::new(RwLock::new(EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )])));
        let ragas = Ragas {
            dataset: Arc::clone(&dataset),
        };

        // Scope the write guard so it is released before the evaluator locks the dataset.
        {
            let mut lock = dataset.write().await;
            let data = lock.0.get_mut("What is Rust?").unwrap();
            data.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let json_output = ragas.to_json().await;
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::RetrieveDocuments(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::AnswerQuery(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    // The dataset-level tests below are fully synchronous, so they need no async runtime.
    #[test]
    fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let data = dataset.0.get_mut("What is Rust?").unwrap();
        data.answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A programming language");
    }

    #[test]
    fn test_evaluation_dataset_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);

        let json_output = dataset.to_json();
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[test]
    fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[test]
    fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::AnswerQuery(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }
}