swiftide_query/query/
pipeline.rs

//! A query pipeline can be used to answer a user query.
//!
//! The pipeline has a sequence of steps:
//!
//! 1. Transform the query (e.g. generating subquestions or embeddings)
//! 2. Retrieve documents from storage
//! 3. Transform these documents into a suitable context for answering
//! 4. Answer the query
//!
//! WARN: The query pipeline is in a very early stage!
//!
//! Under the hood, it uses a [`SearchStrategy`]; the retriever (e.g. Qdrant) must implement
//! [`Retrieve`] for that strategy.
//!
//! A query pipeline is lazy and only runs when a query is submitted with `query`, `query_mut`
//! or `query_all`.
15
16use futures_util::TryFutureExt as _;
17use std::sync::Arc;
18use swiftide_core::{
19    EvaluateQuery,
20    prelude::*,
21    querying::{
22        Answer, Query, QueryState, QueryStream, Retrieve, SearchStrategy, TransformQuery,
23        TransformResponse, search_strategies::SimilaritySingleEmbedding, states,
24    },
25};
26use tokio::sync::mpsc::Sender;
27
28/// The starting point of a query pipeline
29pub struct Pipeline<
30    'stream,
31    STRATEGY: SearchStrategy = SimilaritySingleEmbedding,
32    STATE: QueryState = states::Pending,
33> {
34    search_strategy: STRATEGY,
35    stream: QueryStream<'stream, STATE>,
36    query_sender: Sender<Result<Query<states::Pending>>>,
37    evaluator: Option<Arc<Box<dyn EvaluateQuery>>>,
38    default_concurrency: usize,
39}
40
41/// By default the [`SearchStrategy`] is [`SimilaritySingleEmbedding`], which embed the current
42/// query and returns a collection of documents.
43impl Default for Pipeline<'_, SimilaritySingleEmbedding> {
44    fn default() -> Self {
45        let stream = QueryStream::default();
46        Self {
47            search_strategy: SimilaritySingleEmbedding::default(),
48            query_sender: stream
49                .sender
50                .clone()
51                .expect("Pipeline received stream without query entrypoint"),
52            stream,
53            evaluator: None,
54            default_concurrency: num_cpus::get(),
55        }
56    }
57}
58
59impl<'a, STRATEGY: SearchStrategy> Pipeline<'a, STRATEGY> {
60    /// Create a query pipeline from a [`SearchStrategy`]
61    ///
62    /// # Panics
63    ///
64    /// Panics if the inner stream fails to build
65    #[must_use]
66    pub fn from_search_strategy(strategy: STRATEGY) -> Pipeline<'a, STRATEGY> {
67        let stream = QueryStream::default();
68
69        Pipeline {
70            search_strategy: strategy,
71            query_sender: stream
72                .sender
73                .clone()
74                .expect("Pipeline received stream without query entrypoint"),
75            stream,
76            evaluator: None,
77            default_concurrency: num_cpus::get(),
78        }
79    }
80}
81
82impl<'stream: 'static, STRATEGY> Pipeline<'stream, STRATEGY, states::Pending>
83where
84    STRATEGY: SearchStrategy,
85{
86    /// Evaluate queries with an evaluator
87    #[must_use]
88    pub fn evaluate_with<T: EvaluateQuery + 'stream>(mut self, evaluator: T) -> Self {
89        self.evaluator = Some(Arc::new(Box::new(evaluator)));
90
91        self
92    }
93
94    /// Transform a query into something else, see [`crate::query_transformers`]
95    #[must_use]
96    pub fn then_transform_query<T: TransformQuery + 'stream>(
97        self,
98        transformer: T,
99    ) -> Pipeline<'stream, STRATEGY, states::Pending> {
100        let transformer = Arc::new(transformer);
101
102        let Pipeline {
103            stream,
104            query_sender,
105            search_strategy,
106            evaluator,
107            default_concurrency,
108        } = self;
109
110        let new_stream = stream
111            .map_ok(move |query| {
112                let transformer = Arc::clone(&transformer);
113                let span = tracing::info_span!("then_transform_query", query = ?query);
114
115                tokio::spawn(
116                    async move {
117                        let transformed_query = transformer.transform_query(query).await?;
118                        tracing::debug!(
119                            transformed_query = transformed_query.current(),
120                            query_transformer = transformer.name(),
121                            "Transformed query"
122                        );
123
124                        Ok(transformed_query)
125                    }
126                    .instrument(span.or_current()),
127                )
128                .err_into::<anyhow::Error>()
129            })
130            .try_buffer_unordered(default_concurrency)
131            .map(|x| x.and_then(|x| x));
132
133        Pipeline {
134            stream: new_stream.boxed().into(),
135            search_strategy,
136            query_sender,
137            evaluator,
138            default_concurrency,
139        }
140    }
141}
142
143impl<'stream: 'static, STRATEGY: SearchStrategy + 'stream>
144    Pipeline<'stream, STRATEGY, states::Pending>
145{
146    /// Executes the query based on a search query with a retriever
147    #[must_use]
148    pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<STRATEGY> + 'stream>>(
149        self,
150        retriever: T,
151    ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
152        let retriever = Arc::new(retriever.to_owned());
153        let Pipeline {
154            stream,
155            query_sender,
156            search_strategy,
157            evaluator,
158            default_concurrency,
159        } = self;
160
161        let strategy_for_stream = search_strategy.clone();
162        let evaluator_for_stream = evaluator.clone();
163
164        let new_stream = stream
165            .map_ok(move |query| {
166                let search_strategy = strategy_for_stream.clone();
167                let retriever = Arc::clone(&retriever);
168                let span = tracing::info_span!("then_retrieve", query = ?query);
169                let evaluator_for_stream = evaluator_for_stream.clone();
170
171                tokio::spawn(
172                    async move {
173                        let result = retriever.retrieve(&search_strategy, query).await?;
174
175                        tracing::debug!(
176                            num_documents = result.documents().len(),
177                            total_bytes = result
178                                .documents()
179                                .iter()
180                                .map(|d| d.bytes().len())
181                                .sum::<usize>(),
182                            "Retrieved documents"
183                        );
184
185                        if let Some(evaluator) = evaluator_for_stream.as_ref() {
186                            evaluator.evaluate(result.clone().into()).await?;
187                            Ok(result)
188                        } else {
189                            Ok(result)
190                        }
191                    }
192                    .instrument(span.or_current()),
193                )
194                .err_into::<anyhow::Error>()
195            })
196            .try_buffer_unordered(default_concurrency)
197            .map(|x| x.and_then(|x| x));
198
199        Pipeline {
200            stream: new_stream.boxed().into(),
201            search_strategy: search_strategy.clone(),
202            query_sender,
203            evaluator,
204            default_concurrency,
205        }
206    }
207}
208
209impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
210    /// Transforms a retrieved query into something else
211    #[must_use]
212    pub fn then_transform_response<T: TransformResponse + 'stream>(
213        self,
214        transformer: T,
215    ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
216        let transformer = Arc::new(transformer);
217        let Pipeline {
218            stream,
219            query_sender,
220            search_strategy,
221            evaluator,
222            default_concurrency,
223        } = self;
224
225        let new_stream = stream
226            .map_ok(move |query| {
227                let transformer = Arc::clone(&transformer);
228                let span = tracing::info_span!("then_transform_response", query = ?query);
229                tokio::spawn(
230                    async move {
231                        let transformed_query = transformer.transform_response(query).await?;
232                        tracing::debug!(
233                            transformed_query = transformed_query.current(),
234                            response_transformer = transformer.name(),
235                            "Transformed response"
236                        );
237
238                        Ok(transformed_query)
239                    }
240                    .instrument(span.or_current()),
241                )
242                .err_into::<anyhow::Error>()
243            })
244            .try_buffer_unordered(default_concurrency)
245            .map(|x| x.and_then(|x| x));
246
247        Pipeline {
248            stream: new_stream.boxed().into(),
249            search_strategy,
250            query_sender,
251            evaluator,
252            default_concurrency,
253        }
254    }
255}
256
257impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
258    /// Generates an answer based on previous transformations
259    #[must_use]
260    pub fn then_answer<T: Answer + 'stream>(
261        self,
262        answerer: T,
263    ) -> Pipeline<'stream, STRATEGY, states::Answered> {
264        let answerer = Arc::new(answerer);
265        let Pipeline {
266            stream,
267            query_sender,
268            search_strategy,
269            evaluator,
270            default_concurrency,
271        } = self;
272        let evaluator_for_stream = evaluator.clone();
273
274        let new_stream = stream
275            .map_ok(move |query: Query<states::Retrieved>| {
276                let answerer = Arc::clone(&answerer);
277                let span = tracing::info_span!("then_answer", query = ?query);
278                let evaluator_for_stream = evaluator_for_stream.clone();
279
280                tokio::spawn(
281                    async move {
282                        tracing::debug!(answerer = answerer.name(), "Answering query");
283                        let result = answerer.answer(query).await?;
284
285                        if let Some(evaluator) = evaluator_for_stream.as_ref() {
286                            evaluator.evaluate(result.clone().into()).await?;
287                            Ok(result)
288                        } else {
289                            Ok(result)
290                        }
291                    }
292                    .instrument(span.or_current()),
293                )
294                .err_into::<anyhow::Error>()
295            })
296            .try_buffer_unordered(default_concurrency)
297            .map(|x| x.and_then(|x| x));
298
299        Pipeline {
300            stream: new_stream.boxed().into(),
301            search_strategy,
302            query_sender,
303            evaluator,
304            default_concurrency,
305        }
306    }
307}
308
309impl<STRATEGY: SearchStrategy> Pipeline<'_, STRATEGY, states::Answered> {
310    /// Runs the pipeline with a user query, accepts `&str` as well.
311    ///
312    /// # Errors
313    ///
314    /// Errors if any of the transformations failed or no response was found
315    #[tracing::instrument(skip_all, name = "query_pipeline.query")]
316    pub async fn query(
317        mut self,
318        query: impl Into<Query<states::Pending>>,
319    ) -> Result<Query<states::Answered>> {
320        tracing::debug!("Sending query");
321        let now = std::time::Instant::now();
322
323        self.query_sender.send(Ok(query.into())).await?;
324
325        let answer = self.stream.try_next().await?.ok_or_else(|| {
326            anyhow::anyhow!("Pipeline did not receive a response from the query stream")
327        });
328
329        let elapsed_in_seconds = now.elapsed().as_secs();
330        tracing::warn!(
331            elapsed_in_seconds,
332            "Answered query in {} seconds",
333            elapsed_in_seconds
334        );
335
336        answer
337    }
338
339    /// Runs the pipeline with a user query, accepts `&str` as well.
340    ///
341    /// Does not consume the pipeline and requires a mutable reference. This allows
342    /// the pipeline to be reused.
343    ///
344    /// # Errors
345    ///
346    /// Errors if any of the transformations failed or no response was found
347    #[tracing::instrument(skip_all, name = "query_pipeline.query_mut")]
348    pub async fn query_mut(
349        &mut self,
350        query: impl Into<Query<states::Pending>>,
351    ) -> Result<Query<states::Answered>> {
352        tracing::warn!("Sending query");
353        let now = std::time::Instant::now();
354
355        self.query_sender.send(Ok(query.into())).await?;
356
357        let answer = self
358            .stream
359            .by_ref()
360            .take(1)
361            .try_next()
362            .await?
363            .ok_or_else(|| {
364                anyhow::anyhow!("Pipeline did not receive a response from the query stream")
365            });
366
367        tracing::debug!(?answer, "Received an answer");
368
369        let elapsed_in_seconds = now.elapsed().as_secs();
370        tracing::warn!(
371            elapsed_in_seconds,
372            "Answered query in {} seconds",
373            elapsed_in_seconds
374        );
375
376        answer
377    }
378
379    /// Runs the pipeline with multiple queries
380    ///
381    /// # Errors
382    ///
383    /// Errors if any of the transformations failed, no response was found, or the stream was
384    /// closed.
385    #[tracing::instrument(skip_all, name = "query_pipeline.query_all")]
386    pub async fn query_all(
387        self,
388        queries: Vec<impl Into<Query<states::Pending>> + Clone>,
389    ) -> Result<Vec<Query<states::Answered>>> {
390        tracing::warn!("Sending queries");
391        let now = std::time::Instant::now();
392
393        let Pipeline {
394            query_sender,
395            mut stream,
396            ..
397        } = self;
398
399        for query in &queries {
400            query_sender.send(Ok(query.clone().into())).await?;
401        }
402        tracing::info!("All queries sent");
403
404        let mut results = vec![];
405        while let Some(result) = stream.try_next().await? {
406            tracing::debug!(?result, "Received an answer");
407            results.push(result);
408            if results.len() == queries.len() {
409                break;
410            }
411        }
412
413        let elapsed_in_seconds = now.elapsed().as_secs();
414        tracing::warn!(
415            num_queries = queries.len(),
416            elapsed_in_seconds,
417            "Answered all queries in {} seconds",
418            elapsed_in_seconds
419        );
420        Ok(results)
421    }
422}
423
#[cfg(test)]
mod test {
    use swiftide_core::{
        MockAnswer, MockTransformQuery, MockTransformResponse, querying::search_strategies,
    };

    use super::*;

    type AnsweredPipeline =
        Pipeline<'static, search_strategies::SimilaritySingleEmbedding, states::Answered>;

    /// Builds a pipeline in which every step is a closure: queries pass through unchanged,
    /// retrieval yields no documents, and the answer is always `"Ok"`.
    fn passthrough_pipeline() -> AnsweredPipeline {
        Pipeline::default()
            .then_transform_query(move |query: Query<states::Pending>| Ok(query))
            .then_retrieve(
                move |_: &search_strategies::SimilaritySingleEmbedding,
                      query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(move |query: Query<states::Retrieved>| Ok(query.answered("Ok")))
    }

    #[tokio::test]
    async fn test_closures_in_each_step() {
        let pipeline = passthrough_pipeline();
        let response = pipeline.query("What").await.unwrap();
        assert_eq!(response.answer(), "Ok");
    }

    #[tokio::test]
    async fn test_all_steps_should_accept_dyn_box() {
        let mut query_transformer = MockTransformQuery::new();
        query_transformer.expect_transform_query().returning(Ok);

        let mut response_transformer = MockTransformResponse::new();
        response_transformer
            .expect_transform_response()
            .returning(Ok);
        let mut answer_transformer = MockAnswer::new();
        answer_transformer
            .expect_answer()
            .returning(|query| Ok(query.answered("OK")));

        let pipeline = Pipeline::default()
            .then_transform_query(Box::new(query_transformer) as Box<dyn TransformQuery>)
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Box::new(response_transformer) as Box<dyn TransformResponse>)
            .then_answer(Box::new(answer_transformer) as Box<dyn Answer>);
        let response = pipeline.query("What").await.unwrap();
        assert_eq!(response.answer(), "OK");
    }

    #[tokio::test]
    async fn test_reuse_with_query_mut() {
        let mut pipeline = passthrough_pipeline();

        let response = pipeline.query_mut("What").await.unwrap();
        assert_eq!(response.answer(), "Ok");
        let response = pipeline.query_mut("What").await.unwrap();
        assert_eq!(response.answer(), "Ok");
    }

    #[tokio::test]
    async fn test_query_all_answers_every_query() {
        let pipeline = passthrough_pipeline();

        // Answers arrive in completion order, so only count and content are checked.
        let responses = pipeline.query_all(vec!["What", "Why"]).await.unwrap();
        assert_eq!(responses.len(), 2);
        assert!(responses.iter().all(|response| response.answer() == "Ok"));
    }
}