swiftide_query/query/
pipeline.rs

//! A query pipeline can be used to answer a user query
//!
//! The pipeline has a sequence of steps:
//!     1. Transform the query (e.g. generating subquestions, embeddings)
//!     2. Retrieve documents from storage
//!     3. Transform these documents into a suitable context for answering
//!     4. Answer the query
//!
//! WARN: The query pipeline is in a very early stage!
//!
//! Under the hood, it uses a [`SearchStrategy`] that an implementor of [`Retrieve`] (e.g. Qdrant)
//! must implement.
//!
//! A query pipeline is lazy and only runs when `query`, `query_mut` or `query_all` is called.
15
16use futures_util::TryFutureExt as _;
17use std::sync::Arc;
18use swiftide_core::{
19    EvaluateQuery,
20    prelude::*,
21    querying::{
22        Answer, Query, QueryState, QueryStream, Retrieve, SearchStrategy, TransformQuery,
23        TransformResponse, search_strategies::SimilaritySingleEmbedding, states,
24    },
25};
26use tokio::sync::mpsc::Sender;
27
28/// The starting point of a query pipeline
29pub struct Pipeline<
30    'stream,
31    STRATEGY: SearchStrategy = SimilaritySingleEmbedding,
32    STATE: QueryState = states::Pending,
33> {
34    search_strategy: STRATEGY,
35    stream: QueryStream<'stream, STATE>,
36    query_sender: Sender<Result<Query<states::Pending>>>,
37    evaluator: Option<Arc<Box<dyn EvaluateQuery>>>,
38    default_concurrency: usize,
39}
40
41/// By default the [`SearchStrategy`] is [`SimilaritySingleEmbedding`], which embed the current
42/// query and returns a collection of documents.
43impl Default for Pipeline<'_, SimilaritySingleEmbedding> {
44    fn default() -> Self {
45        let stream = QueryStream::default();
46        Self {
47            search_strategy: SimilaritySingleEmbedding::default(),
48            query_sender: stream
49                .sender
50                .clone()
51                .expect("Pipeline received stream without query entrypoint"),
52            stream,
53            evaluator: None,
54            default_concurrency: num_cpus::get(),
55        }
56    }
57}
58
59impl<'a, STRATEGY: SearchStrategy> Pipeline<'a, STRATEGY> {
60    /// Create a query pipeline from a [`SearchStrategy`]
61    ///
62    /// # Panics
63    ///
64    /// Panics if the inner stream fails to build
65    #[must_use]
66    pub fn from_search_strategy(strategy: STRATEGY) -> Pipeline<'a, STRATEGY> {
67        let stream = QueryStream::default();
68
69        Pipeline {
70            search_strategy: strategy,
71            query_sender: stream
72                .sender
73                .clone()
74                .expect("Pipeline received stream without query entrypoint"),
75            stream,
76            evaluator: None,
77            default_concurrency: num_cpus::get(),
78        }
79    }
80}
81
82impl<'stream: 'static, STRATEGY> Pipeline<'stream, STRATEGY, states::Pending>
83where
84    STRATEGY: SearchStrategy,
85{
86    /// Evaluate queries with an evaluator
87    #[must_use]
88    pub fn evaluate_with<T: EvaluateQuery + 'stream>(mut self, evaluator: T) -> Self {
89        self.evaluator = Some(Arc::new(Box::new(evaluator)));
90
91        self
92    }
93
94    /// Transform a query into something else, see [`crate::query_transformers`]
95    #[must_use]
96    pub fn then_transform_query<T: TransformQuery + 'stream>(
97        self,
98        transformer: T,
99    ) -> Pipeline<'stream, STRATEGY, states::Pending> {
100        let transformer = Arc::new(transformer);
101
102        let Pipeline {
103            stream,
104            query_sender,
105            search_strategy,
106            evaluator,
107            default_concurrency,
108        } = self;
109
110        let new_stream = stream
111            .map_ok(move |query| {
112                let transformer = Arc::clone(&transformer);
113                let span = tracing::info_span!("then_transform_query", query = ?query);
114
115                tokio::spawn(
116                    async move {
117                        let transformed_query = transformer.transform_query(query).await?;
118                        tracing::debug!(
119                            transformed_query = transformed_query.current(),
120                            query_transformer = transformer.name(),
121                            "Transformed query"
122                        );
123
124                        Ok(transformed_query)
125                    }
126                    .instrument(span.or_current()),
127                )
128                .err_into::<anyhow::Error>()
129            })
130            .try_buffer_unordered(default_concurrency)
131            .map(|x| x.and_then(|x| x));
132
133        Pipeline {
134            stream: new_stream.boxed().into(),
135            search_strategy,
136            query_sender,
137            evaluator,
138            default_concurrency,
139        }
140    }
141}
142
143impl<'stream: 'static, STRATEGY: SearchStrategy + 'stream>
144    Pipeline<'stream, STRATEGY, states::Pending>
145{
146    /// Executes the query based on a search query with a retriever
147    #[must_use]
148    pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<STRATEGY> + 'stream>>(
149        self,
150        retriever: T,
151    ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
152        let retriever = Arc::new(retriever.to_owned());
153        let Pipeline {
154            stream,
155            query_sender,
156            search_strategy,
157            evaluator,
158            default_concurrency,
159        } = self;
160
161        let strategy_for_stream = search_strategy.clone();
162        let evaluator_for_stream = evaluator.clone();
163
164        let new_stream = stream
165            .map_ok(move |query| {
166                let search_strategy = strategy_for_stream.clone();
167                let retriever = Arc::clone(&retriever);
168                let span = tracing::info_span!("then_retrieve", query = ?query);
169                let evaluator_for_stream = evaluator_for_stream.clone();
170
171                tokio::spawn(
172                    async move {
173                        let result = retriever.retrieve(&search_strategy, query).await?;
174
175                        tracing::debug!(documents = ?result.documents(), "Retrieved documents");
176
177                        if let Some(evaluator) = evaluator_for_stream.as_ref() {
178                            evaluator.evaluate(result.clone().into()).await?;
179                            Ok(result)
180                        } else {
181                            Ok(result)
182                        }
183                    }
184                    .instrument(span.or_current()),
185                )
186                .err_into::<anyhow::Error>()
187            })
188            .try_buffer_unordered(default_concurrency)
189            .map(|x| x.and_then(|x| x));
190
191        Pipeline {
192            stream: new_stream.boxed().into(),
193            search_strategy: search_strategy.clone(),
194            query_sender,
195            evaluator,
196            default_concurrency,
197        }
198    }
199}
200
201impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
202    /// Transforms a retrieved query into something else
203    #[must_use]
204    pub fn then_transform_response<T: TransformResponse + 'stream>(
205        self,
206        transformer: T,
207    ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
208        let transformer = Arc::new(transformer);
209        let Pipeline {
210            stream,
211            query_sender,
212            search_strategy,
213            evaluator,
214            default_concurrency,
215        } = self;
216
217        let new_stream = stream
218            .map_ok(move |query| {
219                let transformer = Arc::clone(&transformer);
220                let span = tracing::info_span!("then_transform_response", query = ?query);
221                tokio::spawn(
222                    async move {
223                        let transformed_query = transformer.transform_response(query).await?;
224                        tracing::debug!(
225                            transformed_query = transformed_query.current(),
226                            response_transformer = transformer.name(),
227                            "Transformed response"
228                        );
229
230                        Ok(transformed_query)
231                    }
232                    .instrument(span.or_current()),
233                )
234                .err_into::<anyhow::Error>()
235            })
236            .try_buffer_unordered(default_concurrency)
237            .map(|x| x.and_then(|x| x));
238
239        Pipeline {
240            stream: new_stream.boxed().into(),
241            search_strategy,
242            query_sender,
243            evaluator,
244            default_concurrency,
245        }
246    }
247}
248
249impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
250    /// Generates an answer based on previous transformations
251    #[must_use]
252    pub fn then_answer<T: Answer + 'stream>(
253        self,
254        answerer: T,
255    ) -> Pipeline<'stream, STRATEGY, states::Answered> {
256        let answerer = Arc::new(answerer);
257        let Pipeline {
258            stream,
259            query_sender,
260            search_strategy,
261            evaluator,
262            default_concurrency,
263        } = self;
264        let evaluator_for_stream = evaluator.clone();
265
266        let new_stream = stream
267            .map_ok(move |query: Query<states::Retrieved>| {
268                let answerer = Arc::clone(&answerer);
269                let span = tracing::info_span!("then_answer", query = ?query);
270                let evaluator_for_stream = evaluator_for_stream.clone();
271
272                tokio::spawn(
273                    async move {
274                        tracing::debug!(answerer = answerer.name(), "Answering query");
275                        let result = answerer.answer(query).await?;
276
277                        if let Some(evaluator) = evaluator_for_stream.as_ref() {
278                            evaluator.evaluate(result.clone().into()).await?;
279                            Ok(result)
280                        } else {
281                            Ok(result)
282                        }
283                    }
284                    .instrument(span.or_current()),
285                )
286                .err_into::<anyhow::Error>()
287            })
288            .try_buffer_unordered(default_concurrency)
289            .map(|x| x.and_then(|x| x));
290
291        Pipeline {
292            stream: new_stream.boxed().into(),
293            search_strategy,
294            query_sender,
295            evaluator,
296            default_concurrency,
297        }
298    }
299}
300
301impl<STRATEGY: SearchStrategy> Pipeline<'_, STRATEGY, states::Answered> {
302    /// Runs the pipeline with a user query, accepts `&str` as well.
303    ///
304    /// # Errors
305    ///
306    /// Errors if any of the transformations failed or no response was found
307    #[tracing::instrument(skip_all, name = "query_pipeline.query")]
308    pub async fn query(
309        mut self,
310        query: impl Into<Query<states::Pending>>,
311    ) -> Result<Query<states::Answered>> {
312        tracing::debug!("Sending query");
313        let now = std::time::Instant::now();
314
315        self.query_sender.send(Ok(query.into())).await?;
316
317        let answer = self.stream.try_next().await?.ok_or_else(|| {
318            anyhow::anyhow!("Pipeline did not receive a response from the query stream")
319        });
320
321        let elapsed_in_seconds = now.elapsed().as_secs();
322        tracing::warn!(
323            elapsed_in_seconds,
324            "Answered query in {} seconds",
325            elapsed_in_seconds
326        );
327
328        answer
329    }
330
331    /// Runs the pipeline with a user query, accepts `&str` as well.
332    ///
333    /// Does not consume the pipeline and requires a mutable reference. This allows
334    /// the pipeline to be reused.
335    ///
336    /// # Errors
337    ///
338    /// Errors if any of the transformations failed or no response was found
339    #[tracing::instrument(skip_all, name = "query_pipeline.query_mut")]
340    pub async fn query_mut(
341        &mut self,
342        query: impl Into<Query<states::Pending>>,
343    ) -> Result<Query<states::Answered>> {
344        tracing::warn!("Sending query");
345        let now = std::time::Instant::now();
346
347        self.query_sender.send(Ok(query.into())).await?;
348
349        let answer = self
350            .stream
351            .by_ref()
352            .take(1)
353            .try_next()
354            .await?
355            .ok_or_else(|| {
356                anyhow::anyhow!("Pipeline did not receive a response from the query stream")
357            });
358
359        tracing::debug!(?answer, "Received an answer");
360
361        let elapsed_in_seconds = now.elapsed().as_secs();
362        tracing::warn!(
363            elapsed_in_seconds,
364            "Answered query in {} seconds",
365            elapsed_in_seconds
366        );
367
368        answer
369    }
370
371    /// Runs the pipeline with multiple queries
372    ///
373    /// # Errors
374    ///
375    /// Errors if any of the transformations failed, no response was found, or the stream was
376    /// closed.
377    #[tracing::instrument(skip_all, name = "query_pipeline.query_all")]
378    pub async fn query_all(
379        self,
380        queries: Vec<impl Into<Query<states::Pending>> + Clone>,
381    ) -> Result<Vec<Query<states::Answered>>> {
382        tracing::warn!("Sending queries");
383        let now = std::time::Instant::now();
384
385        let Pipeline {
386            query_sender,
387            mut stream,
388            ..
389        } = self;
390
391        for query in &queries {
392            query_sender.send(Ok(query.clone().into())).await?;
393        }
394        tracing::info!("All queries sent");
395
396        let mut results = vec![];
397        while let Some(result) = stream.try_next().await? {
398            tracing::debug!(?result, "Received an answer");
399            results.push(result);
400            if results.len() == queries.len() {
401                break;
402            }
403        }
404
405        let elapsed_in_seconds = now.elapsed().as_secs();
406        tracing::warn!(
407            num_queries = queries.len(),
408            elapsed_in_seconds,
409            "Answered all queries in {} seconds",
410            elapsed_in_seconds
411        );
412        Ok(results)
413    }
414}
415
#[cfg(test)]
mod test {
    use swiftide_core::{
        MockAnswer, MockTransformQuery, MockTransformResponse, querying::search_strategies,
    };

    use super::*;

    /// Every pipeline step accepts a plain closure (or function) as its implementation.
    #[tokio::test]
    async fn test_closures_in_each_step() {
        let answered = Pipeline::default()
            .then_transform_query(|query: Query<states::Pending>| Ok(query))
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(|query: Query<states::Retrieved>| Ok(query.answered("Ok")))
            .query("What")
            .await
            .unwrap();

        assert_eq!(answered.answer(), "Ok");
    }

    /// Boxed trait objects are accepted for the transform and answer steps.
    #[tokio::test]
    async fn test_all_steps_should_accept_dyn_box() {
        let mut answer_mock = MockAnswer::new();
        answer_mock
            .expect_answer()
            .returning(|query| Ok(query.answered("OK")));

        let mut response_mock = MockTransformResponse::new();
        response_mock.expect_transform_response().returning(Ok);

        let mut query_mock = MockTransformQuery::new();
        query_mock.expect_transform_query().returning(Ok);

        // Erase the concrete mock types so the steps only see trait objects.
        let query_transformer: Box<dyn TransformQuery> = Box::new(query_mock);
        let response_transformer: Box<dyn TransformResponse> = Box::new(response_mock);
        let answerer: Box<dyn Answer> = Box::new(answer_mock);

        let answered = Pipeline::default()
            .then_transform_query(query_transformer)
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(response_transformer)
            .then_answer(answerer)
            .query("What")
            .await
            .unwrap();

        assert_eq!(answered.answer(), "OK");
    }

    /// `query_mut` leaves the pipeline usable, so it can answer more than one query.
    #[tokio::test]
    async fn test_reuse_with_query_mut() {
        let mut reusable = Pipeline::default()
            .then_transform_query(|query: Query<states::Pending>| Ok(query))
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(|query: Query<states::Retrieved>| Ok(query.answered("Ok")));

        for _ in 0..2 {
            let answered = reusable.query_mut("What").await.unwrap();
            assert_eq!(answered.answer(), "Ok");
        }
    }
}
486}