swiftide_query/evaluators/
ragas.rs

1//! The Ragas evaluator allows you to export a RAGAS compatible JSON dataset.
2//!
3//! RAGAS requires a ground truth to compare to. You can either record the answers for an initial
4//! dataset, or provide the ground truth yourself.
5//!
6//! Refer to the ragas documentation on how to use the dataset or take a look at a more involved
7//! example at [swiftide-tutorials](https://github.com/bosun-ai/swiftide-tutorial).
8//!
9//! # Example
10//!
11//! ```ignore
12//! # use swiftide_query::*;
13//! # use anyhow::{Result, Context};
14//! # #[tokio::main]
15//! # async fn main() -> anyhow::Result<()> {
16//!
17//! let openai = swiftide::integrations::openai::OpenAi::default();
18//! let qdrant = swiftide::integrations::qdrant::Qdrant::default();
19//!
20//! let ragas = evaluators::ragas::Ragas::from_prepared_questions(questions);
21//!
22//! let pipeline = query::Pipeline::default()
23//! .evaluate_with(ragas.clone())
24//! .then_transform_query(query_transformers::GenerateSubquestions::from_client(openai.clone()))
25//! .then_transform_query(query_transformers::Embed::from_client(
26//! openai.clone(),
27//! ))
28//! .then_retrieve(qdrant.clone())
29//! .then_answer(answers::Simple::from_client(openai.clone()));
30//!
31//! pipeline.query_all(ragas.questions().await).await.unwrap();
32//!
33//! std::fs::write("output.json", ragas.to_json().await).unwrap();
34//! # Ok(())
//! # }
//! ```
36use anyhow::Result;
37use async_trait::async_trait;
38use serde::{Deserialize, Serialize};
39use serde_json::json;
40use std::{collections::HashMap, str::FromStr, sync::Arc};
41use tokio::sync::RwLock;
42
43use swiftide_core::{
44    querying::{states, Query, QueryEvaluation},
45    EvaluateQuery,
46};
47
48/// Ragas evaluator to be used in a pipeline
49#[derive(Debug, Clone)]
50pub struct Ragas {
51    dataset: Arc<RwLock<EvaluationDataSet>>,
52}
53
54/// Row structure for RAGAS compatible JSON
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct EvaluationData {
57    question: String,
58    answer: String,
59    contexts: Vec<String>,
60    ground_truth: String,
61}
62
63/// Dataset for RAGAS compatible JSON, indexed by question
64#[derive(Debug, Clone)]
65pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
66
67impl Ragas {
68    /// Builds a new Ragas evaluator from a list of questions or a list of tuples with questions and
69    /// ground truths. You can also call `parse` to load a dataset from a JSON string.
70    pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
71        Ragas {
72            dataset: Arc::new(RwLock::new(questions.into())),
73        }
74    }
75
76    pub async fn questions(&self) -> Vec<Query<states::Pending>> {
77        self.dataset.read().await.0.keys().map(Into::into).collect()
78    }
79
80    /// Records the current answers as ground truths in the dataset
81    pub async fn record_answers_as_ground_truth(&self) {
82        self.dataset.write().await.record_answers_as_ground_truth();
83    }
84
85    /// Outputs the dataset as a JSON string compatible with RAGAS
86    pub async fn to_json(&self) -> String {
87        self.dataset.read().await.to_json()
88    }
89}
90
91#[async_trait]
92impl EvaluateQuery for Ragas {
93    #[tracing::instrument(skip_all)]
94    async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
95        let mut dataset = self.dataset.write().await;
96        dataset.upsert_evaluation(&query)
97    }
98}
99
100impl EvaluationDataSet {
101    pub(crate) fn record_answers_as_ground_truth(&mut self) {
102        for data in self.0.values_mut() {
103            data.ground_truth.clone_from(&data.answer);
104        }
105    }
106
107    pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
108        match query {
109            QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
110            QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
111        }
112    }
113
114    // For each upsort, check if it exists and update it, or return an error
115    fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
116        let question = query.original();
117        let data = self
118            .0
119            .get_mut(question)
120            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
121
122        data.contexts = query
123            .documents()
124            .iter()
125            .map(|d| d.content().to_string())
126            .collect::<Vec<_>>();
127        Ok(())
128    }
129
130    fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
131        let question = query.original();
132        let data = self
133            .0
134            .get_mut(question)
135            .ok_or_else(|| anyhow::anyhow!("Question not found"))?;
136
137        data.answer = query.answer().to_string();
138
139        Ok(())
140    }
141
142    /// Outputs json for ragas
143    ///
144    /// # Format
145    ///
146    /// ```json
147    /// [
148    ///   {
149    ///   "question": "What is the capital of France?",
150    ///   "answer": "Paris",
151    ///   "contexts": ["Paris is the capital of France"],
152    ///   "ground_truth": "Paris"
153    ///   },
154    ///   {
155    ///   "question": "What is the capital of France?",
156    ///   "answer": "Paris",
157    ///   "contexts": ["Paris is the capital of France"],
158    ///   "ground_truth": "Paris"
159    ///   }
160    /// ]
161    /// ```
162    pub(crate) fn to_json(&self) -> String {
163        let json_value = json!(self.0.values().collect::<Vec<_>>());
164        serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json_value.to_string())
165    }
166}
167
168// Can just do a list of questions leaving ground truth, answers, contexts empty
169impl From<Vec<String>> for EvaluationDataSet {
170    fn from(val: Vec<String>) -> Self {
171        EvaluationDataSet(
172            val.into_iter()
173                .map(|question| {
174                    (
175                        question.clone(),
176                        EvaluationData {
177                            question,
178                            ..EvaluationData::default()
179                        },
180                    )
181                })
182                .collect(),
183        )
184    }
185}
186
187impl From<&[String]> for EvaluationDataSet {
188    fn from(val: &[String]) -> Self {
189        EvaluationDataSet(
190            val.iter()
191                .map(|question| {
192                    (
193                        question.to_string(),
194                        EvaluationData {
195                            question: question.to_string(),
196                            ..EvaluationData::default()
197                        },
198                    )
199                })
200                .collect(),
201        )
202    }
203}
204
205// Can take a list of tuples for questions and ground truths
206impl From<Vec<(String, String)>> for EvaluationDataSet {
207    fn from(val: Vec<(String, String)>) -> Self {
208        EvaluationDataSet(
209            val.into_iter()
210                .map(|(question, ground_truth)| {
211                    (
212                        question.clone(),
213                        EvaluationData {
214                            question,
215                            ground_truth,
216                            ..EvaluationData::default()
217                        },
218                    )
219                })
220                .collect(),
221        )
222    }
223}
224
225/// Parse an existing dataset from a JSON string
226impl FromStr for EvaluationDataSet {
227    type Err = serde_json::Error;
228
229    fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
230        let data: Vec<EvaluationData> = serde_json::from_str(val)?;
231        Ok(EvaluationDataSet(
232            data.into_iter()
233                .map(|data| (data.question.clone(), data))
234                .collect(),
235        ))
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{Query, QueryEvaluation};
    use tokio::sync::RwLock;

    /// The question shared by most tests
    const QUESTION: &str = "What is Rust?";

    /// Expected export of a dataset holding `QUESTION` with a ground truth and nothing else
    const EXPECTED_JSON: &str = "[\n  {\n    \"answer\": \"\",\n    \"contexts\": [],\n    \"ground_truth\": \"A programming language\",\n    \"question\": \"What is Rust?\"\n  }\n]";

    /// Dataset with only the question filled in
    fn question_only_dataset() -> EvaluationDataSet {
        EvaluationDataSet::from(vec![QUESTION.to_string()])
    }

    /// Dataset with the question and a ground truth filled in
    fn dataset_with_ground_truth() -> EvaluationDataSet {
        EvaluationDataSet::from(vec![(
            QUESTION.to_string(),
            "A programming language".to_string(),
        )])
    }

    /// Wraps a dataset in an evaluator
    fn evaluator_for(dataset: EvaluationDataSet) -> Ragas {
        Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        }
    }

    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        // Order is not guaranteed, so only check that every question is present
        for expected in &questions {
            let found = stored_questions
                .iter()
                .any(|q| q.original() == expected.as_str());
            assert!(found);
        }
    }

    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let shared = Arc::new(RwLock::new(dataset_with_ground_truth()));
        let ragas = Ragas {
            dataset: Arc::clone(&shared),
        };

        // Record an answer through the shared handle; the scope releases the write lock before
        // the evaluator needs it
        {
            let mut guard = shared.write().await;
            let row = guard.0.get_mut(QUESTION).unwrap();
            row.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let guard = ragas.dataset.read().await;
        let row = guard.0.get(QUESTION).unwrap();
        assert_eq!(row.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let ragas = evaluator_for(dataset_with_ground_truth());

        let json_output = ragas.to_json().await;

        assert_eq!(json_output, EXPECTED_JSON);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let ragas = evaluator_for(question_only_dataset());

        let query = Query::builder()
            .original(QUESTION)
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::RetrieveDocuments(query))
            .await
            .unwrap();

        let guard = ragas.dataset.read().await;
        let row = guard.0.get(QUESTION).unwrap();
        assert_eq!(row.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let ragas = evaluator_for(question_only_dataset());

        let query = Query::builder()
            .original(QUESTION)
            .current("A systems programming language")
            .build()
            .unwrap();

        ragas
            .evaluate(QueryEvaluation::AnswerQuery(query))
            .await
            .unwrap();

        let guard = ragas.dataset.read().await;
        let row = guard.0.get(QUESTION).unwrap();
        assert_eq!(row.answer, "A systems programming language");
    }

    #[tokio::test]
    async fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = question_only_dataset();
        dataset.0.get_mut(QUESTION).unwrap().answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let row = dataset.0.get(QUESTION).unwrap();
        assert_eq!(row.ground_truth, "A programming language");
    }

    #[tokio::test]
    async fn test_evaluation_dataset_to_json() {
        let dataset = dataset_with_ground_truth();

        let json_output = dataset.to_json();

        assert_eq!(json_output, EXPECTED_JSON);
    }

    #[tokio::test]
    async fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = question_only_dataset();

        let query = Query::builder()
            .original(QUESTION)
            .documents(vec!["Rust is a language".into()])
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::RetrieveDocuments(query);

        dataset.upsert_evaluation(&evaluation).unwrap();

        let row = dataset.0.get(QUESTION).unwrap();
        assert_eq!(row.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = question_only_dataset();

        let query = Query::builder()
            .original(QUESTION)
            .current("A systems programming language")
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::AnswerQuery(query);

        dataset.upsert_evaluation(&evaluation).unwrap();

        let row = dataset.0.get(QUESTION).unwrap();
        assert_eq!(row.answer, "A systems programming language");
    }
}