swiftide_query/answers/
simple.rs1use std::sync::Arc;
3use swiftide_core::{
4 Answer,
5 document::Document,
6 indexing::SimplePrompt,
7 prelude::*,
8 prompt::Prompt,
9 querying::{Query, states},
10};
11
12#[derive(Debug, Clone, Builder)]
22pub struct Simple {
23 #[builder(setter(custom))]
24 client: Arc<dyn SimplePrompt>,
25 #[builder(default = "default_prompt()")]
26 prompt_template: Prompt,
27 #[builder(default, setter(into, strip_option))]
28 document_template: Option<Prompt>,
29}
30
31impl Simple {
32 pub fn builder() -> SimpleBuilder {
33 SimpleBuilder::default()
34 }
35
36 pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
42 SimpleBuilder::default()
43 .client(client)
44 .to_owned()
45 .build()
46 .expect("Failed to build Simple")
47 }
48}
49
50impl SimpleBuilder {
51 pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
52 self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
53 self
54 }
55}
56
57fn default_prompt() -> Prompt {
58 indoc::indoc! {"
59 Answer the following question based on the context provided:
60 {{ question }}
61
62 ## Constraints
63 * Do not include any information that is not in the provided context.
64 * If the question cannot be answered by the provided context, state that it cannot be answered.
65 * Answer the question completely and format it as markdown.
66
67 ## Context
68
69 ---
70 {{ documents }}
71 ---
72 "}
73 .into()
74}
75
76#[async_trait]
77impl Answer for Simple {
78 #[tracing::instrument(skip_all)]
79 async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
80 let mut context = tera::Context::new();
81
82 context.insert("question", query.original());
83
84 let documents = if !query.current().is_empty() {
85 query.current().to_string()
86 } else if let Some(template) = &self.document_template {
87 let mut rendered_documents = Vec::new();
88 for document in query.documents() {
89 let rendered = template
90 .clone()
91 .with_context(tera::Context::from_serialize(document)?)
92 .render()?;
93 rendered_documents.push(rendered);
94 }
95
96 rendered_documents.join("\n---\n")
97 } else {
98 query
99 .documents()
100 .iter()
101 .map(Document::content)
102 .collect::<Vec<_>>()
103 .join("\n---\n")
104 };
105 context.insert("documents", &documents);
106
107 let answer = self
108 .client
109 .prompt(self.prompt_template.clone().with_context(context))
110 .await?;
111
112 Ok(query.answered(answer))
113 }
114}
115
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use insta::assert_snapshot;
    use swiftide_core::{MockSimplePrompt, indexing::Metadata};

    use super::*;

    /// Slot shared between a test and its mock client; holds the prompt the client was given.
    type CapturedPrompt = Arc<Mutex<Option<Prompt>>>;

    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");

    /// Builds a mock client that expects exactly one `prompt` call, stores the prompt it
    /// receives in the returned slot, and answers with an empty string.
    fn capturing_client() -> (MockSimplePrompt, CapturedPrompt) {
        let captured: CapturedPrompt = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&captured);

        let mut client = MockSimplePrompt::new();
        client
            .expect_prompt()
            .withf(move |prompt| {
                // The matcher doubles as a spy: record the prompt, then accept the call.
                sink.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));

        (client, captured)
    }

    /// Two documents, each carrying a single metadata entry.
    fn sample_documents() -> Vec<Document> {
        vec![
            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
            Document::new(
                "Second document",
                Some(Metadata::from(("other", "metadata"))),
            ),
        ]
    }

    /// Takes the captured prompt out of the slot and renders it.
    fn rendered_prompt(captured: &CapturedPrompt) -> String {
        let prompt = captured.lock().unwrap().take().unwrap();
        prompt.render().unwrap()
    }

    #[tokio::test]
    async fn test_uses_current_if_present() {
        let (client, captured) = capturing_client();

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current("A fictional generated summary")
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let answerer = Simple::builder().client(client).build().unwrap();
        answerer.answer(query).await.unwrap();

        let rendered = rendered_prompt(&captured);
        assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_custom_document_template() {
        let (client, captured) = capturing_client();

        // An empty current text forces the answerer to fall back to the documents.
        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current(String::default())
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let answerer = Simple::builder()
            .client(client)
            .document_template(indoc::indoc! {"
                {% for key, value in metadata -%}
                {{ key }}: {{ value }}
                {% endfor -%}

                {{ content }}"})
            .build()
            .unwrap();
        answerer.answer(query).await.unwrap();

        let rendered = rendered_prompt(&captured);
        assert_snapshot!(rendered);
    }
}