use derive_builder::Builder;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::{indexing::Node, prompt::PromptTemplate, SimplePrompt, Transformer};
/// Metadata key under which [`MetadataQACode`] stores the client's response on a node.
pub const NAME: &str = "Questions and Answers (code)";
/// Transformer that prompts a [`SimplePrompt`] client about a node and stores the response in
/// the node's metadata under [`NAME`].
///
/// Construct it with `MetadataQACode::new`, or through the generated `MetadataQACodeBuilder`
/// when the template, question count, or concurrency need to be customized.
#[derive(Debug, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct MetadataQACode {
/// Client the prompt is sent to. It has no builder default and uses a hand-written setter
/// (`MetadataQACodeBuilder::client`) that wraps the given client in an `Arc`.
#[builder(setter(custom))]
client: Arc<dyn SimplePrompt>,
/// Template the prompt is built from. Defaults to the bundled
/// `prompts/metadata_qa_code.prompt.md`.
#[builder(default = "default_prompt()")]
prompt_template: PromptTemplate,
/// Value exposed to the template as the `questions` context value. Defaults to 5.
/// Presumably the number of question/answer pairs requested -- confirm against the template.
#[builder(default = "5")]
num_questions: usize,
/// Value reported by `Transformer::concurrency`. Defaults to `None`.
#[builder(default)]
concurrency: Option<usize>,
}
impl MetadataQACode {
    /// Returns an empty builder for `MetadataQACode`.
    ///
    /// The client has no builder default, so it must be set before calling `build`.
    #[must_use]
    pub fn builder() -> MetadataQACodeBuilder {
        MetadataQACodeBuilder::default()
    }

    /// Returns a builder that already has `client` set.
    ///
    /// All other fields keep their builder defaults until overridden.
    #[must_use]
    pub fn from_client(client: impl SimplePrompt + 'static) -> MetadataQACodeBuilder {
        // `client` returns `&mut Self`; bind the builder to a local so it can be returned by
        // value without cloning a temporary.
        let mut builder = MetadataQACodeBuilder::default();
        builder.client(client);
        builder
    }

    /// Creates a `MetadataQACode` for `client` using the default prompt template,
    /// 5 questions, and no explicit concurrency.
    ///
    /// These values mirror the defaults declared on the builder.
    #[must_use]
    pub fn new(client: impl SimplePrompt + 'static) -> Self {
        Self {
            client: Arc::new(client),
            prompt_template: default_prompt(),
            num_questions: 5,
            concurrency: None,
        }
    }

    /// Sets the concurrency reported by `Transformer::concurrency` and returns the updated
    /// transformer.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = Some(concurrency);
        self
    }
}
/// Builds the default [`PromptTemplate`] from the prompt file embedded at compile time.
fn default_prompt() -> PromptTemplate {
    let bundled: &'static str = include_str!("prompts/metadata_qa_code.prompt.md");
    bundled.into()
}
impl MetadataQACodeBuilder {
    /// Sets the client the transformer will prompt.
    ///
    /// Accepts any concrete [`SimplePrompt`] implementation and stores it as a shared trait
    /// object, replacing any previously configured client.
    pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
        let shared: Arc<dyn SimplePrompt> = Arc::new(client);
        self.client = Some(shared);
        self
    }
}
#[async_trait]
impl Transformer for MetadataQACode {
    /// Prompts the client about `node` and records the response in the node's metadata.
    ///
    /// The prompt is built from the configured template with the node attached and
    /// `num_questions` exposed as the `questions` context value. The response is inserted
    /// into `node.metadata` under [`NAME`].
    ///
    /// # Errors
    ///
    /// Returns the client's error if the prompt call fails; the node is not returned in
    /// that case.
    #[tracing::instrument(skip_all, name = "transformers.metadata_qa_code")]
    async fn transform_node(&self, mut node: Node) -> Result<Node> {
        let base = self.prompt_template.to_prompt();
        let with_node = base.with_node(&node);
        let request = with_node.with_context_value("questions", self.num_questions);

        let answer = self.client.prompt(request).await?;
        node.metadata.insert(NAME, answer);

        Ok(node)
    }

    /// Returns the concurrency configured for this transformer, if any.
    fn concurrency(&self) -> Option<usize> {
        self.concurrency
    }
}
#[cfg(test)]
mod test {
use swiftide_core::MockSimplePrompt;
use super::*;
// Snapshot test: renders the default template for a node containing "test" with
// `questions` set to 5 and compares the output against the stored insta snapshot.
#[tokio::test]
async fn test_template() {
let template = default_prompt();
let prompt = template
.to_prompt()
.with_node(&Node::new("test"))
.with_context_value("questions", 5);
insta::assert_snapshot!(prompt.render().await.unwrap());
}
// Uses a mocked client that always returns a fixed response and checks that the
// transformer stores that response unchanged under the metadata key.
#[tokio::test]
async fn test_metadata_qacode() {
let mut client = MockSimplePrompt::new();
client
.expect_prompt()
.returning(|_| Ok("Q1: Hello\nA1: World".to_string()));
let transformer = MetadataQACode::builder().client(client).build().unwrap();
let node = Node::new("Some text");
let result = transformer.transform_node(node).await.unwrap();
// The key is spelled out literally, so a change to `NAME` is caught here.
assert_eq!(
result.metadata.get("Questions and Answers (code)").unwrap(),
"Q1: Hello\nA1: World"
);
}
}