swiftide_query/answers/simple.rs

1//! Generate an answer based on the current query
2use std::sync::Arc;
3use swiftide_core::{
4    document::Document,
5    indexing::SimplePrompt,
6    prelude::*,
7    querying::{states, Query},
8    template::Template,
9    Answer,
10};
11
/// Generates an answer by prompting an LLM with the original question and a block of context.
///
/// For most general purposes this answerer should provide a sensible default. The context handed
/// to the LLM comes from one of two places:
///
/// * `Query::current`, when it is non-empty. This is meant to hold a transformation that has
///   already been applied to the retrieved documents.
/// * Otherwise, the retrieved documents themselves, separated from each other by `---` lines.
///
/// Regardless of where the context comes from, it is paired with the _original_ question, not
/// the current one.
///
/// Optionally, a custom document template can be provided to control how each document is
/// rendered. The template receives the serialized document as its context.
#[derive(Debug, Clone, Builder)]
pub struct Simple {
    /// Client that receives the final prompt. Set through the custom [`SimpleBuilder::client`]
    /// setter, which accepts any [`SimplePrompt`] implementation.
    #[builder(setter(custom))]
    client: Arc<dyn SimplePrompt>,
    /// Template for the final prompt, rendered with `question` and `documents` in its context.
    /// Falls back to the built-in default prompt when not set on the builder.
    #[builder(default = "default_prompt()")]
    prompt_template: Template,
    /// Optional template used to render each retrieved document when `Query::current` is empty.
    /// The builder setter accepts anything convertible into a [`Template`].
    #[builder(default, setter(into, strip_option))]
    document_template: Option<Template>,
}
30
31impl Simple {
32    pub fn builder() -> SimpleBuilder {
33        SimpleBuilder::default()
34    }
35
36    /// Builds a new simple answer generator from a client that implements [`SimplePrompt`].
37    ///
38    /// # Panics
39    ///
40    /// Panics if the build failed
41    pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
42        SimpleBuilder::default()
43            .client(client)
44            .to_owned()
45            .build()
46            .expect("Failed to build Simple")
47    }
48}
49
50impl SimpleBuilder {
51    pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
52        self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
53        self
54    }
55}
56
57fn default_prompt() -> Template {
58    indoc::indoc! {"
59    Answer the following question based on the context provided:
60    {{ question }}
61
62    ## Constraints
63    * Do not include any information that is not in the provided context.
64    * If the question cannot be answered by the provided context, state that it cannot be answered.
65    * Answer the question completely and format it as markdown.
66
67    ## Context
68
69    ---
70    {{ documents }}
71    ---
72    "}
73    .into()
74}
75
76#[async_trait]
77impl Answer for Simple {
78    #[tracing::instrument(skip_all)]
79    async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
80        let mut context = tera::Context::new();
81
82        context.insert("question", query.original());
83
84        let documents = if !query.current().is_empty() {
85            query.current().to_string()
86        } else if let Some(template) = &self.document_template {
87            let mut rendered_documents = Vec::new();
88            for document in query.documents() {
89                let rendered = template
90                    .render(&tera::Context::from_serialize(document)?)
91                    .await?;
92                rendered_documents.push(rendered);
93            }
94
95            rendered_documents.join("\n---\n")
96        } else {
97            query
98                .documents()
99                .iter()
100                .map(Document::content)
101                .collect::<Vec<_>>()
102                .join("\n---\n")
103        };
104        context.insert("documents", &documents);
105
106        let answer = self
107            .client
108            .prompt(self.prompt_template.to_prompt().with_context(context))
109            .await?;
110
111        Ok(query.answered(answer))
112    }
113}
114
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use insta::assert_snapshot;
    use swiftide_core::{indexing::Metadata, MockSimplePrompt};

    use super::*;

    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");

    /// Two documents with distinct content and metadata, shared by the tests below.
    fn sample_documents() -> Vec<Document> {
        vec![
            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
            Document::new(
                "Second document",
                Some(Metadata::from(("other", "metadata"))),
            ),
        ]
    }

    #[tokio::test]
    async fn test_uses_current_if_present() {
        let mut mock_client = MockSimplePrompt::new();

        // Record the prompt in `returning`, which receives it by value, rather than as a side
        // effect of the `withf` matcher.
        let received_prompt = Arc::new(Mutex::new(None));
        let captured = Arc::clone(&received_prompt);
        mock_client
            .expect_prompt()
            .once()
            .returning(move |prompt| {
                *captured.lock().unwrap() = Some(prompt);
                Ok(String::default())
            });

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current("A fictional generated summary")
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let transformer = Simple::builder().client(mock_client).build().unwrap();

        transformer.answer(query).await.unwrap();

        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
        let rendered = received_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_custom_document_template() {
        let mut mock_client = MockSimplePrompt::new();

        let received_prompt = Arc::new(Mutex::new(None));
        let captured = Arc::clone(&received_prompt);
        mock_client
            .expect_prompt()
            .once()
            .returning(move |prompt| {
                *captured.lock().unwrap() = Some(prompt);
                Ok(String::default())
            });

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current(String::default())
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let transformer = Simple::builder()
            .client(mock_client)
            .document_template(indoc::indoc! {"
                {% for key, value in metadata -%}
                    {{ key }}: {{ value }}
                {% endfor -%}

                {{ content }}"})
            .build()
            .unwrap();

        transformer.answer(query).await.unwrap();

        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
        let rendered = received_prompt.render().await.unwrap();
        assert_snapshot!(rendered);
    }

    /// Without a current value or a document template, the raw document contents are joined with
    /// `---` separators.
    #[tokio::test]
    async fn test_joins_document_contents_without_template() {
        let mut mock_client = MockSimplePrompt::new();

        let received_prompt = Arc::new(Mutex::new(None));
        let captured = Arc::clone(&received_prompt);
        mock_client
            .expect_prompt()
            .once()
            .returning(move |prompt| {
                *captured.lock().unwrap() = Some(prompt);
                Ok(String::default())
            });

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current(String::default())
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let transformer = Simple::builder().client(mock_client).build().unwrap();

        transformer.answer(query).await.unwrap();

        let received_prompt = received_prompt.lock().unwrap().take().unwrap();
        let rendered = received_prompt.render().await.unwrap();
        assert!(rendered.contains("First document\n---\nSecond document"));
    }
}