1use futures_util::TryFutureExt as _;
17use std::sync::Arc;
18use swiftide_core::{
19 EvaluateQuery,
20 prelude::*,
21 querying::{
22 Answer, Query, QueryState, QueryStream, Retrieve, SearchStrategy, TransformQuery,
23 TransformResponse, search_strategies::SimilaritySingleEmbedding, states,
24 },
25};
26use tokio::sync::mpsc::Sender;
27
28pub struct Pipeline<
30 'stream,
31 STRATEGY: SearchStrategy = SimilaritySingleEmbedding,
32 STATE: QueryState = states::Pending,
33> {
34 search_strategy: STRATEGY,
35 stream: QueryStream<'stream, STATE>,
36 query_sender: Sender<Result<Query<states::Pending>>>,
37 evaluator: Option<Arc<Box<dyn EvaluateQuery>>>,
38 default_concurrency: usize,
39}
40
41impl Default for Pipeline<'_, SimilaritySingleEmbedding> {
44 fn default() -> Self {
45 let stream = QueryStream::default();
46 Self {
47 search_strategy: SimilaritySingleEmbedding::default(),
48 query_sender: stream
49 .sender
50 .clone()
51 .expect("Pipeline received stream without query entrypoint"),
52 stream,
53 evaluator: None,
54 default_concurrency: num_cpus::get(),
55 }
56 }
57}
58
59impl<'a, STRATEGY: SearchStrategy> Pipeline<'a, STRATEGY> {
60 #[must_use]
66 pub fn from_search_strategy(strategy: STRATEGY) -> Pipeline<'a, STRATEGY> {
67 let stream = QueryStream::default();
68
69 Pipeline {
70 search_strategy: strategy,
71 query_sender: stream
72 .sender
73 .clone()
74 .expect("Pipeline received stream without query entrypoint"),
75 stream,
76 evaluator: None,
77 default_concurrency: num_cpus::get(),
78 }
79 }
80}
81
82impl<'stream: 'static, STRATEGY> Pipeline<'stream, STRATEGY, states::Pending>
83where
84 STRATEGY: SearchStrategy,
85{
86 #[must_use]
88 pub fn evaluate_with<T: EvaluateQuery + 'stream>(mut self, evaluator: T) -> Self {
89 self.evaluator = Some(Arc::new(Box::new(evaluator)));
90
91 self
92 }
93
94 #[must_use]
96 pub fn then_transform_query<T: TransformQuery + 'stream>(
97 self,
98 transformer: T,
99 ) -> Pipeline<'stream, STRATEGY, states::Pending> {
100 let transformer = Arc::new(transformer);
101
102 let Pipeline {
103 stream,
104 query_sender,
105 search_strategy,
106 evaluator,
107 default_concurrency,
108 } = self;
109
110 let new_stream = stream
111 .map_ok(move |query| {
112 let transformer = Arc::clone(&transformer);
113 let span = tracing::info_span!("then_transform_query", query = ?query);
114
115 tokio::spawn(
116 async move {
117 let transformed_query = transformer.transform_query(query).await?;
118 tracing::debug!(
119 transformed_query = transformed_query.current(),
120 query_transformer = transformer.name(),
121 "Transformed query"
122 );
123
124 Ok(transformed_query)
125 }
126 .instrument(span.or_current()),
127 )
128 .err_into::<anyhow::Error>()
129 })
130 .try_buffer_unordered(default_concurrency)
131 .map(|x| x.and_then(|x| x));
132
133 Pipeline {
134 stream: new_stream.boxed().into(),
135 search_strategy,
136 query_sender,
137 evaluator,
138 default_concurrency,
139 }
140 }
141}
142
143impl<'stream: 'static, STRATEGY: SearchStrategy + 'stream>
144 Pipeline<'stream, STRATEGY, states::Pending>
145{
146 #[must_use]
148 pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<STRATEGY> + 'stream>>(
149 self,
150 retriever: T,
151 ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
152 let retriever = Arc::new(retriever.to_owned());
153 let Pipeline {
154 stream,
155 query_sender,
156 search_strategy,
157 evaluator,
158 default_concurrency,
159 } = self;
160
161 let strategy_for_stream = search_strategy.clone();
162 let evaluator_for_stream = evaluator.clone();
163
164 let new_stream = stream
165 .map_ok(move |query| {
166 let search_strategy = strategy_for_stream.clone();
167 let retriever = Arc::clone(&retriever);
168 let span = tracing::info_span!("then_retrieve", query = ?query);
169 let evaluator_for_stream = evaluator_for_stream.clone();
170
171 tokio::spawn(
172 async move {
173 let result = retriever.retrieve(&search_strategy, query).await?;
174
175 tracing::debug!(documents = ?result.documents(), "Retrieved documents");
176
177 if let Some(evaluator) = evaluator_for_stream.as_ref() {
178 evaluator.evaluate(result.clone().into()).await?;
179 Ok(result)
180 } else {
181 Ok(result)
182 }
183 }
184 .instrument(span.or_current()),
185 )
186 .err_into::<anyhow::Error>()
187 })
188 .try_buffer_unordered(default_concurrency)
189 .map(|x| x.and_then(|x| x));
190
191 Pipeline {
192 stream: new_stream.boxed().into(),
193 search_strategy: search_strategy.clone(),
194 query_sender,
195 evaluator,
196 default_concurrency,
197 }
198 }
199}
200
201impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
202 #[must_use]
204 pub fn then_transform_response<T: TransformResponse + 'stream>(
205 self,
206 transformer: T,
207 ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
208 let transformer = Arc::new(transformer);
209 let Pipeline {
210 stream,
211 query_sender,
212 search_strategy,
213 evaluator,
214 default_concurrency,
215 } = self;
216
217 let new_stream = stream
218 .map_ok(move |query| {
219 let transformer = Arc::clone(&transformer);
220 let span = tracing::info_span!("then_transform_response", query = ?query);
221 tokio::spawn(
222 async move {
223 let transformed_query = transformer.transform_response(query).await?;
224 tracing::debug!(
225 transformed_query = transformed_query.current(),
226 response_transformer = transformer.name(),
227 "Transformed response"
228 );
229
230 Ok(transformed_query)
231 }
232 .instrument(span.or_current()),
233 )
234 .err_into::<anyhow::Error>()
235 })
236 .try_buffer_unordered(default_concurrency)
237 .map(|x| x.and_then(|x| x));
238
239 Pipeline {
240 stream: new_stream.boxed().into(),
241 search_strategy,
242 query_sender,
243 evaluator,
244 default_concurrency,
245 }
246 }
247}
248
249impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
250 #[must_use]
252 pub fn then_answer<T: Answer + 'stream>(
253 self,
254 answerer: T,
255 ) -> Pipeline<'stream, STRATEGY, states::Answered> {
256 let answerer = Arc::new(answerer);
257 let Pipeline {
258 stream,
259 query_sender,
260 search_strategy,
261 evaluator,
262 default_concurrency,
263 } = self;
264 let evaluator_for_stream = evaluator.clone();
265
266 let new_stream = stream
267 .map_ok(move |query: Query<states::Retrieved>| {
268 let answerer = Arc::clone(&answerer);
269 let span = tracing::info_span!("then_answer", query = ?query);
270 let evaluator_for_stream = evaluator_for_stream.clone();
271
272 tokio::spawn(
273 async move {
274 tracing::debug!(answerer = answerer.name(), "Answering query");
275 let result = answerer.answer(query).await?;
276
277 if let Some(evaluator) = evaluator_for_stream.as_ref() {
278 evaluator.evaluate(result.clone().into()).await?;
279 Ok(result)
280 } else {
281 Ok(result)
282 }
283 }
284 .instrument(span.or_current()),
285 )
286 .err_into::<anyhow::Error>()
287 })
288 .try_buffer_unordered(default_concurrency)
289 .map(|x| x.and_then(|x| x));
290
291 Pipeline {
292 stream: new_stream.boxed().into(),
293 search_strategy,
294 query_sender,
295 evaluator,
296 default_concurrency,
297 }
298 }
299}
300
301impl<STRATEGY: SearchStrategy> Pipeline<'_, STRATEGY, states::Answered> {
302 #[tracing::instrument(skip_all, name = "query_pipeline.query")]
308 pub async fn query(
309 mut self,
310 query: impl Into<Query<states::Pending>>,
311 ) -> Result<Query<states::Answered>> {
312 tracing::debug!("Sending query");
313 let now = std::time::Instant::now();
314
315 self.query_sender.send(Ok(query.into())).await?;
316
317 let answer = self.stream.try_next().await?.ok_or_else(|| {
318 anyhow::anyhow!("Pipeline did not receive a response from the query stream")
319 });
320
321 let elapsed_in_seconds = now.elapsed().as_secs();
322 tracing::warn!(
323 elapsed_in_seconds,
324 "Answered query in {} seconds",
325 elapsed_in_seconds
326 );
327
328 answer
329 }
330
331 #[tracing::instrument(skip_all, name = "query_pipeline.query_mut")]
340 pub async fn query_mut(
341 &mut self,
342 query: impl Into<Query<states::Pending>>,
343 ) -> Result<Query<states::Answered>> {
344 tracing::warn!("Sending query");
345 let now = std::time::Instant::now();
346
347 self.query_sender.send(Ok(query.into())).await?;
348
349 let answer = self
350 .stream
351 .by_ref()
352 .take(1)
353 .try_next()
354 .await?
355 .ok_or_else(|| {
356 anyhow::anyhow!("Pipeline did not receive a response from the query stream")
357 });
358
359 tracing::debug!(?answer, "Received an answer");
360
361 let elapsed_in_seconds = now.elapsed().as_secs();
362 tracing::warn!(
363 elapsed_in_seconds,
364 "Answered query in {} seconds",
365 elapsed_in_seconds
366 );
367
368 answer
369 }
370
371 #[tracing::instrument(skip_all, name = "query_pipeline.query_all")]
378 pub async fn query_all(
379 self,
380 queries: Vec<impl Into<Query<states::Pending>> + Clone>,
381 ) -> Result<Vec<Query<states::Answered>>> {
382 tracing::warn!("Sending queries");
383 let now = std::time::Instant::now();
384
385 let Pipeline {
386 query_sender,
387 mut stream,
388 ..
389 } = self;
390
391 for query in &queries {
392 query_sender.send(Ok(query.clone().into())).await?;
393 }
394 tracing::info!("All queries sent");
395
396 let mut results = vec![];
397 while let Some(result) = stream.try_next().await? {
398 tracing::debug!(?result, "Received an answer");
399 results.push(result);
400 if results.len() == queries.len() {
401 break;
402 }
403 }
404
405 let elapsed_in_seconds = now.elapsed().as_secs();
406 tracing::warn!(
407 num_queries = queries.len(),
408 elapsed_in_seconds,
409 "Answered all queries in {} seconds",
410 elapsed_in_seconds
411 );
412 Ok(results)
413 }
414}
415
#[cfg(test)]
mod test {
    use swiftide_core::{
        MockAnswer, MockTransformQuery, MockTransformResponse, querying::search_strategies,
    };

    use super::*;

    /// Every builder step accepts a plain closure or function.
    #[tokio::test]
    async fn test_closures_in_each_step() {
        let answered = Pipeline::default()
            .then_transform_query(move |pending: Query<states::Pending>| Ok(pending))
            .then_retrieve(
                move |_: &search_strategies::SimilaritySingleEmbedding,
                      pending: Query<states::Pending>| {
                    Ok(pending.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(move |retrieved: Query<states::Retrieved>| Ok(retrieved.answered("Ok")))
            .query("What")
            .await
            .unwrap();

        assert_eq!(answered.answer(), "Ok");
    }

    /// Every builder step accepts a boxed trait object.
    #[tokio::test]
    async fn test_all_steps_should_accept_dyn_box() {
        let mut mock_query_transformer = MockTransformQuery::new();
        mock_query_transformer
            .expect_transform_query()
            .returning(Ok);

        let mut mock_response_transformer = MockTransformResponse::new();
        mock_response_transformer
            .expect_transform_response()
            .returning(Ok);

        let mut mock_answerer = MockAnswer::new();
        mock_answerer
            .expect_answer()
            .returning(|retrieved| Ok(retrieved.answered("OK")));

        let boxed_query_transformer = Box::new(mock_query_transformer) as Box<dyn TransformQuery>;
        let boxed_response_transformer =
            Box::new(mock_response_transformer) as Box<dyn TransformResponse>;
        let boxed_answerer = Box::new(mock_answerer) as Box<dyn Answer>;

        let answered = Pipeline::default()
            .then_transform_query(boxed_query_transformer)
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 pending: Query<states::Pending>| {
                    Ok(pending.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(boxed_response_transformer)
            .then_answer(boxed_answerer)
            .query("What")
            .await
            .unwrap();

        assert_eq!(answered.answer(), "OK");
    }

    /// A pipeline queried through `query_mut` can be queried again.
    #[tokio::test]
    async fn test_reuse_with_query_mut() {
        let mut reusable = Pipeline::default()
            .then_transform_query(move |pending: Query<states::Pending>| Ok(pending))
            .then_retrieve(
                move |_: &search_strategies::SimilaritySingleEmbedding,
                      pending: Query<states::Pending>| {
                    Ok(pending.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(move |retrieved: Query<states::Retrieved>| Ok(retrieved.answered("Ok")));

        // Two consecutive rounds against the same pipeline instance.
        for _ in 0..2 {
            let answered = reusable.query_mut("What").await.unwrap();
            assert_eq!(answered.answer(), "Ok");
        }
    }
}