swiftide_query/answers/
simple.rs

1//! Generate an answer based on the current query
2use std::sync::Arc;
3use swiftide_core::{
4    document::Document,
5    indexing::SimplePrompt,
6    prelude::*,
7    querying::{states, Query},
8    template::Template,
9    Answer,
10};
11
/// Generate an answer based on the current query
///
/// For most general purposes, this transformer should provide a sensible default. It takes either
/// a transformation that has already been applied to the documents (in `Query::current`), or the documents themselves,
/// and will then feed them as context with the _original_ question to an llm to generate an
/// answer.
///
/// Optionally, a custom document template can be provided to render the documents in a specific way.
#[derive(Debug, Clone, Builder)]
pub struct Simple {
    /// Client that receives the rendered prompt and returns the answer. It is set through the
    /// hand-written `SimpleBuilder::client` setter, hence `setter(custom)`.
    #[builder(setter(custom))]
    client: Arc<dyn SimplePrompt>,
    /// Template for the prompt sent to the client; it is given the `question` and `documents`
    /// context values. Falls back to `default_prompt()` when not set on the builder.
    #[builder(default = "default_prompt()")]
    prompt_template: Template,
    /// Optional template used to render each retrieved document when `Query::current` is empty.
    /// When absent, the raw content of each document is used instead.
    #[builder(default, setter(into, strip_option))]
    document_template: Option<Template>,
}
29
30impl Simple {
31    pub fn builder() -> SimpleBuilder {
32        SimpleBuilder::default()
33    }
34
35    /// Builds a new simple answer generator from a client that implements [`SimplePrompt`].
36    ///
37    /// # Panics
38    ///
39    /// Panics if the build failed
40    pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
41        SimpleBuilder::default()
42            .client(client)
43            .to_owned()
44            .build()
45            .expect("Failed to build Simple")
46    }
47}
48
49impl SimpleBuilder {
50    pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
51        self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
52        self
53    }
54}
55
56fn default_prompt() -> Template {
57    indoc::indoc! {"
58    Answer the following question based on the context provided:
59    {{ question }}
60
61    ## Constraints
62    * Do not include any information that is not in the provided context.
63    * If the question cannot be answered by the provided context, state that it cannot be answered.
64    * Answer the question completely and format it as markdown.
65
66    ## Context
67
68    ---
69    {{ documents }}
70    ---
71    "}
72    .into()
73}
74
75#[async_trait]
76impl Answer for Simple {
77    #[tracing::instrument(skip_all)]
78    async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
79        let mut context = tera::Context::new();
80
81        context.insert("question", query.original());
82
83        let documents = if !query.current().is_empty() {
84            query.current().to_string()
85        } else if let Some(template) = &self.document_template {
86            let mut rendered_documents = Vec::new();
87            for document in query.documents() {
88                let rendered = template
89                    .render(&tera::Context::from_serialize(document)?)
90                    .await?;
91                rendered_documents.push(rendered);
92            }
93
94            rendered_documents.join("\n---\n")
95        } else {
96            query
97                .documents()
98                .iter()
99                .map(Document::content)
100                .collect::<Vec<_>>()
101                .join("\n---\n")
102        };
103        context.insert("documents", &documents);
104
105        let answer = self
106            .client
107            .prompt(self.prompt_template.to_prompt().with_context(context))
108            .await?;
109
110        Ok(query.answered(answer))
111    }
112}
113
114#[cfg(test)]
115mod test {
116    use std::sync::Mutex;
117
118    use insta::assert_snapshot;
119    use swiftide_core::{indexing::Metadata, MockSimplePrompt};
120
121    use super::*;
122
123    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");
124
125    #[tokio::test]
126    async fn test_uses_current_if_present() {
127        let mut mock_client = MockSimplePrompt::new();
128
129        // I'll buy a beer for the first person who can think of a less insane way to do this
130        let received_prompt = Arc::new(Mutex::new(None));
131        let cloned = received_prompt.clone();
132        mock_client
133            .expect_prompt()
134            .withf(move |prompt| {
135                cloned.lock().unwrap().replace(prompt.clone());
136                true
137            })
138            .once()
139            .returning(|_| Ok(String::default()));
140
141        let documents = vec![
142            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
143            Document::new(
144                "Second document",
145                Some(Metadata::from(("other", "metadata"))),
146            ),
147        ];
148        let query: Query<states::Retrieved> = Query::builder()
149            .original("original")
150            .current("A fictional generated summary")
151            .state(states::Retrieved)
152            .documents(documents)
153            .build()
154            .unwrap();
155
156        let transformer = Simple::builder().client(mock_client).build().unwrap();
157
158        transformer.answer(query).await.unwrap();
159
160        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
161        let rendered = received_prompt.render().await.unwrap();
162        assert_snapshot!(rendered);
163    }
164
165    #[tokio::test]
166    async fn test_custom_document_template() {
167        let mut mock_client = MockSimplePrompt::new();
168
169        // I'll buy a beer for the first person who can think of a less insane way to do this
170        let received_prompt = Arc::new(Mutex::new(None));
171        let cloned = received_prompt.clone();
172        mock_client
173            .expect_prompt()
174            .withf(move |prompt| {
175                cloned.lock().unwrap().replace(prompt.clone());
176                true
177            })
178            .once()
179            .returning(|_| Ok(String::default()));
180
181        let documents = vec![
182            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
183            Document::new(
184                "Second document",
185                Some(Metadata::from(("other", "metadata"))),
186            ),
187        ];
188        let query: Query<states::Retrieved> = Query::builder()
189            .original("original")
190            .current(String::default())
191            .state(states::Retrieved)
192            .documents(documents)
193            .build()
194            .unwrap();
195
196        let transformer = Simple::builder()
197            .client(mock_client)
198            .document_template(indoc::indoc! {"
199                {% for key, value in metadata -%}
200                    {{ key }}: {{ value }}
201                {% endfor -%}
202
203                {{ content }}"})
204            .build()
205            .unwrap();
206
207        transformer.answer(query).await.unwrap();
208
209        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
210        let rendered = received_prompt.render().await.unwrap();
211        assert_snapshot!(rendered);
212    }
213}