swiftide_query/answers/
simple.rs1use std::sync::Arc;
3use swiftide_core::{
4 document::Document,
5 indexing::SimplePrompt,
6 prelude::*,
7 querying::{states, Query},
8 template::Template,
9 Answer,
10};
11
/// Answers a retrieved query by sending a single prompt, built from the original
/// question and a context block, to a [`SimplePrompt`] client.
///
/// The context block is the query's `current` value when that is non-empty.
/// Otherwise it is built from the retrieved documents:
/// - each document is rendered through `document_template` when one is set;
/// - otherwise the raw document content is used.
///
/// The pieces are then joined with `---` separator lines.
///
/// Construct with [`Simple::builder`] or [`Simple::from_client`].
#[derive(Debug, Clone, Builder)]
pub struct Simple {
    /// Client that receives the final rendered prompt and returns the answer.
    /// Set through the custom [`SimpleBuilder::client`] setter, which accepts any
    /// `SimplePrompt` implementation and stores it behind an `Arc`.
    #[builder(setter(custom))]
    client: Arc<dyn SimplePrompt>,
    /// Template for the final prompt. It is rendered with the `question` and
    /// `documents` context variables. Defaults to `default_prompt()`.
    #[builder(default = "default_prompt()")]
    prompt_template: Template,
    /// Optional per-document template. Each document is serialized into a `tera`
    /// context and rendered through it. When unset, raw document content is used.
    #[builder(default, setter(into, strip_option))]
    document_template: Option<Template>,
}
29
30impl Simple {
31 pub fn builder() -> SimpleBuilder {
32 SimpleBuilder::default()
33 }
34
35 pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
41 SimpleBuilder::default()
42 .client(client)
43 .to_owned()
44 .build()
45 .expect("Failed to build Simple")
46 }
47}
48
49impl SimpleBuilder {
50 pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
51 self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
52 self
53 }
54}
55
56fn default_prompt() -> Template {
57 indoc::indoc! {"
58 Answer the following question based on the context provided:
59 {{ question }}
60
61 ## Constraints
62 * Do not include any information that is not in the provided context.
63 * If the question cannot be answered by the provided context, state that it cannot be answered.
64 * Answer the question completely and format it as markdown.
65
66 ## Context
67
68 ---
69 {{ documents }}
70 ---
71 "}
72 .into()
73}
74
75#[async_trait]
76impl Answer for Simple {
77 #[tracing::instrument(skip_all)]
78 async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
79 let mut context = tera::Context::new();
80
81 context.insert("question", query.original());
82
83 let documents = if !query.current().is_empty() {
84 query.current().to_string()
85 } else if let Some(template) = &self.document_template {
86 let mut rendered_documents = Vec::new();
87 for document in query.documents() {
88 let rendered = template
89 .render(&tera::Context::from_serialize(document)?)
90 .await?;
91 rendered_documents.push(rendered);
92 }
93
94 rendered_documents.join("\n---\n")
95 } else {
96 query
97 .documents()
98 .iter()
99 .map(Document::content)
100 .collect::<Vec<_>>()
101 .join("\n---\n")
102 };
103 context.insert("documents", &documents);
104
105 let answer = self
106 .client
107 .prompt(self.prompt_template.to_prompt().with_context(context))
108 .await?;
109
110 Ok(query.answered(answer))
111 }
112}
113
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use insta::assert_snapshot;
    use swiftide_core::{indexing::Metadata, MockSimplePrompt};

    use super::*;

    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");

    /// Two documents with distinct content and metadata, shared by the tests below.
    fn sample_documents() -> Vec<Document> {
        vec![
            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
            Document::new(
                "Second document",
                Some(Metadata::from(("other", "metadata"))),
            ),
        ]
    }

    /// Builds a retrieved query over `sample_documents()` with the given `current` value.
    fn retrieved_query(current: &str) -> Query<states::Retrieved> {
        Query::builder()
            .original("original")
            .current(current)
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap()
    }

    // NOTE: `assert_snapshot!` stays inside each test body. insta names unnamed
    // snapshots after the calling function, so moving it into a shared helper
    // would rename the stored snapshots.

    #[tokio::test]
    async fn test_uses_current_if_present() {
        let mut mock_client = MockSimplePrompt::new();

        // Capture the prompt sent to the client so its rendering can be snapshotted.
        let received_prompt = Arc::new(Mutex::new(None));
        let cloned = received_prompt.clone();
        mock_client
            .expect_prompt()
            .withf(move |prompt| {
                cloned.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));

        let answerer = Simple::builder().client(mock_client).build().unwrap();

        // A non-empty `current` should be used instead of the documents.
        answerer
            .answer(retrieved_query("A fictional generated summary"))
            .await
            .unwrap();

        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
        let rendered = received_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_custom_document_template() {
        let mut mock_client = MockSimplePrompt::new();

        // Capture the prompt sent to the client so its rendering can be snapshotted.
        let received_prompt = Arc::new(Mutex::new(None));
        let cloned = received_prompt.clone();
        mock_client
            .expect_prompt()
            .withf(move |prompt| {
                cloned.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));

        let answerer = Simple::builder()
            .client(mock_client)
            .document_template(indoc::indoc! {"
                {% for key, value in metadata -%}
                {{ key }}: {{ value }}
                {% endfor -%}

                {{ content }}"})
            .build()
            .unwrap();

        // An empty `current` forces the documents through the custom template.
        answerer.answer(retrieved_query("")).await.unwrap();

        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
        let rendered = received_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }
}