swiftide_query/answers/
simple.rs1use std::sync::Arc;
3use swiftide_core::{
4 document::Document,
5 indexing::SimplePrompt,
6 prelude::*,
7 querying::{states, Query},
8 template::Template,
9 Answer,
10};
11
12#[derive(Debug, Clone, Builder)]
22pub struct Simple {
23 #[builder(setter(custom))]
24 client: Arc<dyn SimplePrompt>,
25 #[builder(default = "default_prompt()")]
26 prompt_template: Template,
27 #[builder(default, setter(into, strip_option))]
28 document_template: Option<Template>,
29}
30
31impl Simple {
32 pub fn builder() -> SimpleBuilder {
33 SimpleBuilder::default()
34 }
35
36 pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
42 SimpleBuilder::default()
43 .client(client)
44 .to_owned()
45 .build()
46 .expect("Failed to build Simple")
47 }
48}
49
50impl SimpleBuilder {
51 pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
52 self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
53 self
54 }
55}
56
57fn default_prompt() -> Template {
58 indoc::indoc! {"
59 Answer the following question based on the context provided:
60 {{ question }}
61
62 ## Constraints
63 * Do not include any information that is not in the provided context.
64 * If the question cannot be answered by the provided context, state that it cannot be answered.
65 * Answer the question completely and format it as markdown.
66
67 ## Context
68
69 ---
70 {{ documents }}
71 ---
72 "}
73 .into()
74}
75
76#[async_trait]
77impl Answer for Simple {
78 #[tracing::instrument(skip_all)]
79 async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
80 let mut context = tera::Context::new();
81
82 context.insert("question", query.original());
83
84 let documents = if !query.current().is_empty() {
85 query.current().to_string()
86 } else if let Some(template) = &self.document_template {
87 let mut rendered_documents = Vec::new();
88 for document in query.documents() {
89 let rendered = template
90 .render(&tera::Context::from_serialize(document)?)
91 .await?;
92 rendered_documents.push(rendered);
93 }
94
95 rendered_documents.join("\n---\n")
96 } else {
97 query
98 .documents()
99 .iter()
100 .map(Document::content)
101 .collect::<Vec<_>>()
102 .join("\n---\n")
103 };
104 context.insert("documents", &documents);
105
106 let answer = self
107 .client
108 .prompt(self.prompt_template.to_prompt().with_context(context))
109 .await?;
110
111 Ok(query.answered(answer))
112 }
113}
114
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use insta::assert_snapshot;
    use swiftide_core::{indexing::Metadata, MockSimplePrompt};

    use super::*;

    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");

    /// Two documents, each carrying one metadata entry, shared by the tests below.
    fn sample_documents() -> Vec<Document> {
        vec![
            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
            Document::new(
                "Second document",
                Some(Metadata::from(("other", "metadata"))),
            ),
        ]
    }

    #[tokio::test]
    async fn test_uses_current_if_present() {
        // Records the prompt handed to the mock so the test can render and snapshot it.
        let captured = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&captured);

        let mut mock_client = MockSimplePrompt::new();
        mock_client
            .expect_prompt()
            .withf(move |prompt| {
                sink.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current("A fictional generated summary")
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let answerer = Simple::builder().client(mock_client).build().unwrap();
        answerer.answer(query).await.unwrap();

        let sent_prompt = captured.lock().unwrap().take().unwrap();
        let rendered = sent_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_custom_document_template() {
        // Records the prompt handed to the mock so the test can render and snapshot it.
        let captured = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&captured);

        let mut mock_client = MockSimplePrompt::new();
        mock_client
            .expect_prompt()
            .withf(move |prompt| {
                sink.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));

        // An empty `current` makes the answerer fall through to the document template.
        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current(String::default())
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let answerer = Simple::builder()
            .client(mock_client)
            .document_template(indoc::indoc! {"
            {% for key, value in metadata -%}
            {{ key }}: {{ value }}
            {% endfor -%}

            {{ content }}"})
            .build()
            .unwrap();
        answerer.answer(query).await.unwrap();

        let sent_prompt = captured.lock().unwrap().take().unwrap();
        let rendered = sent_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }
}