use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tokio::sync::RwLock;
use swiftide_core::{
querying::{states, Query, QueryEvaluation},
EvaluateQuery,
};
/// Query evaluator that records evaluation results into an [`EvaluationDataSet`].
///
/// The dataset is held behind an `Arc<RwLock<..>>`, so cloning this struct yields a handle
/// to the same underlying dataset rather than a copy of it.
#[derive(Debug, Clone)]
pub struct Ragas {
/// Shared dataset, guarded by an async read/write lock.
dataset: Arc<RwLock<EvaluationDataSet>>,
}
/// One evaluation record: a question plus the data gathered for it.
///
/// Derives `Serialize`/`Deserialize`, so the field names below are the keys used when a
/// record is written to or read from JSON. `Default` leaves every field empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvaluationData {
/// The question text; the dataset constructors in this file also use it as the map key.
question: String,
/// The answer recorded for the question (set from an answered query).
answer: String,
/// The documents recorded for the question (set from a retrieved query).
contexts: Vec<String>,
/// The reference answer; either supplied up front or copied from `answer`.
ground_truth: String,
}
/// A set of [`EvaluationData`] records stored in a `HashMap`.
///
/// Every constructor in this file keys a record by its question text.
#[derive(Debug, Clone)]
pub struct EvaluationDataSet(HashMap<String, EvaluationData>);
impl Ragas {
    /// Creates an evaluator from anything that converts into an [`EvaluationDataSet`].
    pub fn from_prepared_questions(questions: impl Into<EvaluationDataSet>) -> Self {
        let dataset = Arc::new(RwLock::new(questions.into()));
        Self { dataset }
    }

    /// Returns one pending query per question key currently stored in the dataset.
    ///
    /// The queries come out in the iteration order of the underlying map.
    pub async fn questions(&self) -> Vec<Query<states::Pending>> {
        let guard = self.dataset.read().await;
        guard.0.keys().map(Into::into).collect()
    }

    /// Takes the write lock and copies every record's answer into its ground truth.
    pub async fn record_answers_as_ground_truth(&self) {
        let mut guard = self.dataset.write().await;
        guard.record_answers_as_ground_truth();
    }

    /// Takes the read lock and serializes the dataset to a JSON string.
    pub async fn to_json(&self) -> String {
        let guard = self.dataset.read().await;
        guard.to_json()
    }
}
#[async_trait]
impl EvaluateQuery for Ragas {
    /// Records the given evaluation in the shared dataset.
    ///
    /// Acquires the write lock and forwards to `EvaluationDataSet::upsert_evaluation`;
    /// whatever that call returns is returned unchanged.
    #[tracing::instrument(skip_all)]
    async fn evaluate(&self, query: QueryEvaluation) -> Result<()> {
        self.dataset.write().await.upsert_evaluation(&query)
    }
}
impl EvaluationDataSet {
pub(crate) fn record_answers_as_ground_truth(&mut self) {
for data in self.0.values_mut() {
data.ground_truth.clone_from(&data.answer);
}
}
pub(crate) fn upsert_evaluation(&mut self, query: &QueryEvaluation) -> Result<()> {
match query {
QueryEvaluation::RetrieveDocuments(query) => self.upsert_retrieved_documents(query),
QueryEvaluation::AnswerQuery(query) => self.upsert_answer(query),
}
}
fn upsert_retrieved_documents(&mut self, query: &Query<states::Retrieved>) -> Result<()> {
let question = query.original();
let data = self
.0
.get_mut(question)
.ok_or_else(|| anyhow::anyhow!("Question not found"))?;
data.contexts = query.documents().to_vec();
Ok(())
}
fn upsert_answer(&mut self, query: &Query<states::Answered>) -> Result<()> {
let question = query.original();
let data = self
.0
.get_mut(question)
.ok_or_else(|| anyhow::anyhow!("Question not found"))?;
data.answer = query.answer().to_string();
Ok(())
}
pub(crate) fn to_json(&self) -> String {
json!(self.0.values().collect::<Vec<_>>()).to_string()
}
}
impl From<Vec<String>> for EvaluationDataSet {
fn from(val: Vec<String>) -> Self {
EvaluationDataSet(
val.into_iter()
.map(|question| {
(
question.clone(),
EvaluationData {
question,
..EvaluationData::default()
},
)
})
.collect(),
)
}
}
impl From<&[String]> for EvaluationDataSet {
fn from(val: &[String]) -> Self {
EvaluationDataSet(
val.iter()
.map(|question| {
(
question.to_string(),
EvaluationData {
question: question.to_string(),
..EvaluationData::default()
},
)
})
.collect(),
)
}
}
impl From<Vec<(String, String)>> for EvaluationDataSet {
fn from(val: Vec<(String, String)>) -> Self {
EvaluationDataSet(
val.into_iter()
.map(|(question, ground_truth)| {
(
question.clone(),
EvaluationData {
question,
ground_truth,
..EvaluationData::default()
},
)
})
.collect(),
)
}
}
impl FromStr for EvaluationDataSet {
type Err = serde_json::Error;
fn from_str(val: &str) -> std::prelude::v1::Result<Self, Self::Err> {
let data: Vec<EvaluationData> = serde_json::from_str(val)?;
Ok(EvaluationDataSet(
data.into_iter()
.map(|data| (data.question.clone(), data))
.collect(),
))
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Ragas` and `EvaluationDataSet`.
    //
    // Tests that go through the lock or `evaluate` need the async runtime; tests that only
    // call synchronous `EvaluationDataSet` methods are plain `#[test]`s.
    use super::*;
    use std::sync::Arc;
    use swiftide_core::querying::{states, Query, QueryEvaluation};
    use tokio::sync::RwLock;

    // Every prepared question must come back as a pending query.
    #[tokio::test]
    async fn test_ragas_from_prepared_questions() {
        let questions = vec!["What is Rust?".to_string(), "What is Tokio?".to_string()];
        let ragas = Ragas::from_prepared_questions(questions.clone());

        let stored_questions = ragas.questions().await;
        assert_eq!(stored_questions.len(), questions.len());

        for question in questions {
            assert!(stored_questions.iter().any(|q| q.original() == question));
        }
    }

    // Recording answers overwrites an existing ground truth with the current answer.
    #[tokio::test]
    async fn test_ragas_record_answers_as_ground_truth() {
        let dataset = Arc::new(RwLock::new(EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )])));
        let ragas = Ragas {
            dataset: dataset.clone(),
        };

        // Scope the write guard so it is released before `ragas` takes the lock itself.
        {
            let mut lock = dataset.write().await;
            let data = lock.0.get_mut("What is Rust?").unwrap();
            data.answer = "A systems programming language".to_string();
        }

        ragas.record_answers_as_ground_truth().await;

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A systems programming language");
    }

    #[tokio::test]
    async fn test_ragas_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let json_output = ragas.to_json().await;
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_retrieved_documents() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .state(
                states::RetrievedBuilder::default()
                    .documents(vec!["Rust is a language".to_string()])
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::RetrieveDocuments(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[tokio::test]
    async fn test_evaluate_query_upsert_answer() {
        let dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let ragas = Ragas {
            dataset: Arc::new(RwLock::new(dataset)),
        };

        let query = Query::builder()
            .original("What is Rust?")
            .state(
                states::AnsweredBuilder::default()
                    .answer("A systems programming language")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let evaluation = QueryEvaluation::AnswerQuery(query);

        ragas.evaluate(evaluation).await.unwrap();

        let updated_data = ragas.dataset.read().await;
        let data = updated_data.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }

    #[test]
    fn test_evaluation_dataset_record_answers_as_ground_truth() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);
        let data = dataset.0.get_mut("What is Rust?").unwrap();
        data.answer = "A programming language".to_string();

        dataset.record_answers_as_ground_truth();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.ground_truth, "A programming language");
    }

    #[test]
    fn test_evaluation_dataset_to_json() {
        let dataset = EvaluationDataSet::from(vec![(
            "What is Rust?".to_string(),
            "A programming language".to_string(),
        )]);

        let json_output = dataset.to_json();
        let expected_json = "[{\"answer\":\"\",\"contexts\":[],\"ground_truth\":\"A programming language\",\"question\":\"What is Rust?\"}]";
        assert_eq!(json_output, expected_json);
    }

    #[test]
    fn test_evaluation_dataset_upsert_retrieved_documents() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .state(
                states::RetrievedBuilder::default()
                    .documents(vec!["Rust is a language".to_string()])
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();

        dataset
            .upsert_evaluation(&QueryEvaluation::RetrieveDocuments(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.contexts, vec!["Rust is a language"]);
    }

    #[test]
    fn test_evaluation_dataset_upsert_answer() {
        let mut dataset = EvaluationDataSet::from(vec!["What is Rust?".to_string()]);

        let query = Query::builder()
            .original("What is Rust?")
            .state(
                states::AnsweredBuilder::default()
                    .answer("A systems programming language")
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();

        dataset
            .upsert_evaluation(&QueryEvaluation::AnswerQuery(query))
            .unwrap();

        let data = dataset.0.get("What is Rust?").unwrap();
        assert_eq!(data.answer, "A systems programming language");
    }
}