swiftide_query/answers/
simple.rs

1//! Generate an answer based on the current query
2use std::sync::Arc;
3use swiftide_core::{
4    Answer,
5    document::Document,
6    indexing::SimplePrompt,
7    prelude::*,
8    prompt::Prompt,
9    querying::{Query, states},
10};
11
12/// Generate an answer based on the current query
13///
14/// For most general purposes, this transformer should provide a sensible default. It takes either
15/// a transformation that has already been applied to the documents (in `Query::current`), or the
16/// documents themselves, and will then feed them as context with the _original_ question to an llm
17/// to generate an answer.
18///
19/// Optionally, a custom document template can be provided to render the documents in a specific
20/// way.
21#[derive(Debug, Clone, Builder)]
22pub struct Simple {
23    #[builder(setter(custom))]
24    client: Arc<dyn SimplePrompt>,
25    #[builder(default = "default_prompt()")]
26    prompt_template: Prompt,
27    #[builder(default, setter(into, strip_option))]
28    document_template: Option<Prompt>,
29}
30
31impl Simple {
32    pub fn builder() -> SimpleBuilder {
33        SimpleBuilder::default()
34    }
35
36    /// Builds a new simple answer generator from a client that implements [`SimplePrompt`].
37    ///
38    /// # Panics
39    ///
40    /// Panics if the build failed
41    pub fn from_client(client: impl SimplePrompt + 'static) -> Simple {
42        SimpleBuilder::default()
43            .client(client)
44            .to_owned()
45            .build()
46            .expect("Failed to build Simple")
47    }
48}
49
50impl SimpleBuilder {
51    pub fn client(&mut self, client: impl SimplePrompt + 'static) -> &mut Self {
52        self.client = Some(Arc::new(client) as Arc<dyn SimplePrompt>);
53        self
54    }
55}
56
57fn default_prompt() -> Prompt {
58    indoc::indoc! {"
59    Answer the following question based on the context provided:
60    {{ question }}
61
62    ## Constraints
63    * Do not include any information that is not in the provided context.
64    * If the question cannot be answered by the provided context, state that it cannot be answered.
65    * Answer the question completely and format it as markdown.
66
67    ## Context
68
69    ---
70    {{ documents }}
71    ---
72    "}
73    .into()
74}
75
76#[async_trait]
77impl Answer for Simple {
78    #[tracing::instrument(skip_all)]
79    async fn answer(&self, query: Query<states::Retrieved>) -> Result<Query<states::Answered>> {
80        let mut context = tera::Context::new();
81
82        context.insert("question", query.original());
83
84        let documents = if !query.current().is_empty() {
85            query.current().to_string()
86        } else if let Some(template) = &self.document_template {
87            let mut rendered_documents = Vec::new();
88            for document in query.documents() {
89                let rendered = template
90                    .clone()
91                    .with_context(tera::Context::from_serialize(document)?)
92                    .render()?;
93                rendered_documents.push(rendered);
94            }
95
96            rendered_documents.join("\n---\n")
97        } else {
98            query
99                .documents()
100                .iter()
101                .map(Document::content)
102                .collect::<Vec<_>>()
103                .join("\n---\n")
104        };
105        context.insert("documents", &documents);
106
107        let answer = self
108            .client
109            .prompt(self.prompt_template.clone().with_context(context))
110            .await?;
111
112        Ok(query.answered(answer))
113    }
114}
115
#[cfg(test)]
mod test {
    use std::sync::Mutex;

    use insta::assert_snapshot;
    use swiftide_core::{MockSimplePrompt, indexing::Metadata};

    use super::*;

    assert_default_prompt_snapshot!("question" => "What is love?", "documents" => "My context");

    /// Shared slot the mock client writes the prompt it received into.
    type CapturedPrompt = Arc<Mutex<Option<Prompt>>>;

    /// Creates a mock client that expects exactly one `prompt` call and answers with an empty
    /// string.
    ///
    /// The received prompt is recorded from inside `withf`, which always accepts. Read it back
    /// with [`render_captured`].
    fn capturing_client() -> (MockSimplePrompt, CapturedPrompt) {
        let mut mock_client = MockSimplePrompt::new();
        let received_prompt: CapturedPrompt = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&received_prompt);
        mock_client
            .expect_prompt()
            .withf(move |prompt| {
                sink.lock().unwrap().replace(prompt.clone());
                true
            })
            .once()
            .returning(|_| Ok(String::default()));
        (mock_client, received_prompt)
    }

    /// Two documents with distinct content and metadata.
    fn sample_documents() -> Vec<Document> {
        vec![
            Document::new("First document", Some(Metadata::from(("some", "metadata")))),
            Document::new(
                "Second document",
                Some(Metadata::from(("other", "metadata"))),
            ),
        ]
    }

    /// Takes the prompt recorded by [`capturing_client`] and renders it.
    ///
    /// # Panics
    ///
    /// Panics if the client was never prompted or the prompt fails to render.
    fn render_captured(received_prompt: &CapturedPrompt) -> String {
        let prompt = received_prompt
            .lock()
            .unwrap()
            .take()
            .expect("client was not prompted");
        prompt.render().unwrap()
    }

    #[tokio::test]
    async fn test_uses_current_if_present() {
        let (mock_client, received_prompt) = capturing_client();

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current("A fictional generated summary")
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let transformer = Simple::builder().client(mock_client).build().unwrap();

        transformer.answer(query).await.unwrap();

        // The snapshot assertion stays in the test fn so insta keeps deriving the snapshot name
        // from it.
        let rendered = render_captured(&received_prompt);
        assert_snapshot!(rendered);
    }

    #[tokio::test]
    async fn test_custom_document_template() {
        let (mock_client, received_prompt) = capturing_client();

        let query: Query<states::Retrieved> = Query::builder()
            .original("original")
            .current(String::default())
            .state(states::Retrieved)
            .documents(sample_documents())
            .build()
            .unwrap();

        let transformer = Simple::builder()
            .client(mock_client)
            .document_template(indoc::indoc! {"
                {% for key, value in metadata -%}
                    {{ key }}: {{ value }}
                {% endfor -%}

                {{ content }}"})
            .build()
            .unwrap();

        transformer.answer(query).await.unwrap();

        let rendered = render_captured(&received_prompt);
        assert_snapshot!(rendered);
    }
}