swiftide_query/evaluators/
ragas.rs

1/*!
2The Ragas evaluator allows you to export a RAGAS compatible JSON dataset.
3
4RAGAS requires a ground truth to compare to. You can either record the answers for an initial dataset, or provide the ground truth yourself.
5
6Refer to the ragas documentation on how to use the dataset or take a look at a more involved
7example at [swiftide-tutorials](https://github.com/bosun-ai/swiftide-tutorial).
8
9# Example
10
11```ignore
12# use swiftide_query::*;
13# use anyhow::{Result, Context};
14# #[tokio::main]
15# async fn main() -> anyhow::Result<()> {
16
17let openai = swiftide::integrations::openai::OpenAi::default();
18let qdrant = swiftide::integrations::qdrant::Qdrant::default();
19
20let ragas = evaluators::ragas::Ragas::from_prepared_questions(questions);
21
22let pipeline = query::Pipeline::default()
23    .evaluate_with(ragas.clone())
24    .then_transform_query(query_transformers::GenerateSubquestions::from_client(openai.clone()))
25    .then_transform_query(query_transformers::Embed::from_client(
26        openai.clone(),
27    ))
28    .then_retrieve(qdrant.clone())
29    .then_answer(answers::Simple::from_client(openai.clone()));
30
31pipeline.query_all(ragas.questions().await).await.unwrap();
32
33std::fs::write("output.json", ragas.to_json().await).unwrap();
# Ok(())
# }
```
*/
37use anyhow::Result;
38use async_trait::async_trait;
39use serde::{Deserialize, Serialize};
40use serde_json::json;
41use std::{collections::HashMap, str::FromStr, sync::Arc};
42use tokio::sync::RwLock;
43
44use swiftide_core::{
45    querying::{states, Query, QueryEvaluation},
46    EvaluateQuery,
47};
48
49/// Ragas evaluator to be used in a pipeline
50#[derive(Debug, Clone)]
51pub struct Ragas {
52    dataset: Arc<RwLock<EvaluationDataSet>>,
53}
54
55/// Row structure for RAGAS compatible JSON
56#[derive(Debug, Clone, Default, Serialize, Deserialize)]
57pub struct EvaluationData {
58    question: String,
59    answer: String,
60    contexts: Vec<String>,
61    ground_truth: String,
62}
63
64/// Dataset for RAGAS compatible JSON, indexed by question
65#[derive(Debug, Clone)]
66pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
67
68impl Ragas {
69    /// Builds a new Ragas evaluator from a list of questions or a list of tuples with questions and
70    /// ground truths. You can also call `parse` to load a dataset from a JSON string.
71    pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
72        Ragas {
73            dataset: Arc::new(RwLock::new(questions.into())),
74        }
75    }
76
77    pub async fn questions(&self) -> Vec<Query<states::Pending>> {
78        self.dataset.read().await.0.keys().map(Into::into).collect()
79    }
80
81    /// Records the current answers as ground truths in the dataset
82    pub async fn record_answers_as_ground_truth(&self) {
83        self.dataset.write().await.record_answers_as_ground_truth();
84    }
85
86    /// Outputs the dataset as a JSON string compatible with RAGAS
87    pub async fn to_json(&self) -> String {
88        self.dataset.read().await.to_json()
89    }
90}
91
92#[async_trait]
93impl EvaluateQuery for Ragas {
94    #[tracing::instrument(skip_all)]
95    async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
96        let mut dataset = self.dataset.write().await;
97        dataset.upsert_evaluation(&query)
98    }
99}
100
101impl EvaluationDataSet {
102    pub(crate) fn record_answers_as_ground_truth(&mut self) {
103        for data in self.0.values_mut() {
104            data.ground_truth.clone_from(&data.answer);
105        }
106    }
107
108    pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
109        match query {
110            QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
111            QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
112        }
113    }
114
115    // For each upsort, check if it exists and update it, or return an error
116    fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
117        let question = query.original();
118        let data = self
119            .0
120            .get_mut(question)
121            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
122
123        data.contexts = query
124            .documents()
125            .iter()
126            .map(|d| d.content().to_string())
127            .collect::<Vec<_>>();
128        Ok(())
129    }
130
131    fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
132        let question = query.original();
133        let data = self
134            .0
135            .get_mut(question)
136            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
137
138        data.answer = query.answer().to_string();
139
140        Ok(())
141    }
142
143    /// Outputs json for ragas
144    ///
145    /// # Format
146    ///
147    /// ```json
148    /// [
149    ///   {
150    ///   "question": "What is the capital of France?",
151    ///   "answer": "Paris",
152    ///   "contexts": ["Paris is the capital of France"],
153    ///   "ground_truth": "Paris"
154    ///   },
155    ///   {
156    ///   "question": "What is the capital of France?",
157    ///   "answer": "Paris",
158    ///   "contexts": ["Paris is the capital of France"],
159    ///   "ground_truth": "Paris"
160    ///   }
161    /// ]
162    /// ```
163    pub(crate) fn to_json(&self) -> String {
164        json!(self.0.values().collect::<Vec<_>>()).to_string()
165    }
166}
167
168// Can just do a list of questions leaving ground truth, answers, contexts empty
169impl From<Vec<String>> for EvaluationDataSet {
170    fn from(val: Vec<String>) -> Self {
171        EvaluationDataSet(
172            val.into_iter()
173                .map(|question| {
174                    (
175                        question.clone(),
176                        EvaluationData {
177                            question,
178                            ..EvaluationData::default()
179                        },
180                    )
181                })
182                .collect(),
183        )
184    }
185}
186
187impl From<&[String]> for EvaluationDataSet {
188    fn from(val: &[String]) -> Self {
189        EvaluationDataSet(
190            val.iter()
191                .map(|question| {
192                    (
193                        question.to_string(),
194                        EvaluationData {
195                            question: question.to_string(),
196                            ..EvaluationData::default()
197                        },
198                    )
199                })
200                .collect(),
201        )
202    }
203}
204
205// Can take a list of tuples for questions and ground truths
206impl From<Vec<(String, String)>> for EvaluationDataSet {
207    fn from(val: Vec<(String, String)>) -> Self {
208        EvaluationDataSet(
209            val.into_iter()
210                .map(|(question, ground_truth)| {
211                    (
212                        question.clone(),
213                        EvaluationData {
214                            question,
215                            ground_truth,
216                            ..EvaluationData::default()
217                        },
218                    )
219                })
220                .collect(),
221        )
222    }
223}
224
225/// Parse an existing dataset from a JSON string
226impl FromStr for EvaluationDataSet {
227    type Err = serde_json::Error;
228
229    fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
230        let data: Vec<EvaluationData> = serde_json::from_str(val)?;
231        Ok(EvaluationDataSet(
232            data.into_iter()
233                .map(|data| (data.question.clone(), data))
234                .collect(),
235        ))
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{Query, QueryEvaluation};
    use tokio::sync::RwLock;

    // Every prepared question must come back as a query, in any order
    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        for question in questions {
            assert!(stored_questions.iter().any(|q| q.original() == question));
        }
    }

    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let dataset = Arc::new(RwLock::new(EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )])));
        let ragas = Ragas {
            dataset: Arc::clone(&dataset),
        };

        // Scoped so the write lock is released before the evaluator takes its own
        {
            let mut lock = dataset.write().await;
            let data = lock.0.get_mut("What is Rust?").unwrap();
            data.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let json_output = ragas.to_json().await;
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::RetrieveDocuments(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::AnswerQuery(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    // The dataset itself is synchronous, so the tests below do not need an async runtime

    #[test]
    fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let data = dataset.0.get_mut("What is Rust?").unwrap();
        data.answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A programming language");
    }

    #[test]
    fn test_evaluation_dataset_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);

        let json_output = dataset.to_json();
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[test]
    fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[test]
    fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .current("A systems programming language")
            .build()
            .unwrap();
        dataset
            .upsert_evaluation(&QueryEvaluation::AnswerQuery(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    // Questions that were never prepared must be rejected instead of silently added
    #[test]
    fn test_evaluation_dataset_upsert_unknown_question_errors() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Tokio?")
            .current("An async runtime")
            .build()
            .unwrap();
        let result = dataset.upsert_evaluation(&QueryEvaluation::AnswerQuery(query));

        assert!(result.is_err());
        assert!(!dataset.0.contains_key("What is Tokio?"));
    }
}