1use futures_util::TryFutureExt as _;
17use std::sync::Arc;
18use swiftide_core::{
19 EvaluateQuery,
20 prelude::*,
21 querying::{
22 Answer, Query, QueryState, QueryStream, Retrieve, SearchStrategy, TransformQuery,
23 TransformResponse, search_strategies::SimilaritySingleEmbedding, states,
24 },
25};
26use tokio::sync::mpsc::Sender;
27
28pub struct Pipeline<
30 'stream,
31 STRATEGY: SearchStrategy = SimilaritySingleEmbedding,
32 STATE: QueryState = states::Pending,
33> {
34 search_strategy: STRATEGY,
35 stream: QueryStream<'stream, STATE>,
36 query_sender: Sender<Result<Query<states::Pending>>>,
37 evaluator: Option<Arc<Box<dyn EvaluateQuery>>>,
38 default_concurrency: usize,
39}
40
41impl Default for Pipeline<'_, SimilaritySingleEmbedding> {
44 fn default() -> Self {
45 let stream = QueryStream::default();
46 Self {
47 search_strategy: SimilaritySingleEmbedding::default(),
48 query_sender: stream
49 .sender
50 .clone()
51 .expect("Pipeline received stream without query entrypoint"),
52 stream,
53 evaluator: None,
54 default_concurrency: num_cpus::get(),
55 }
56 }
57}
58
59impl<'a, STRATEGY: SearchStrategy> Pipeline<'a, STRATEGY> {
60 #[must_use]
66 pub fn from_search_strategy(strategy: STRATEGY) -> Pipeline<'a, STRATEGY> {
67 let stream = QueryStream::default();
68
69 Pipeline {
70 search_strategy: strategy,
71 query_sender: stream
72 .sender
73 .clone()
74 .expect("Pipeline received stream without query entrypoint"),
75 stream,
76 evaluator: None,
77 default_concurrency: num_cpus::get(),
78 }
79 }
80}
81
82impl<'stream: 'static, STRATEGY> Pipeline<'stream, STRATEGY, states::Pending>
83where
84 STRATEGY: SearchStrategy,
85{
86 #[must_use]
88 pub fn evaluate_with<T: EvaluateQuery + 'stream>(mut self, evaluator: T) -> Self {
89 self.evaluator = Some(Arc::new(Box::new(evaluator)));
90
91 self
92 }
93
94 #[must_use]
96 pub fn then_transform_query<T: TransformQuery + 'stream>(
97 self,
98 transformer: T,
99 ) -> Pipeline<'stream, STRATEGY, states::Pending> {
100 let transformer = Arc::new(transformer);
101
102 let Pipeline {
103 stream,
104 query_sender,
105 search_strategy,
106 evaluator,
107 default_concurrency,
108 } = self;
109
110 let new_stream = stream
111 .map_ok(move |query| {
112 let transformer = Arc::clone(&transformer);
113 let span = tracing::info_span!("then_transform_query", query = ?query);
114
115 tokio::spawn(
116 async move {
117 let transformed_query = transformer.transform_query(query).await?;
118 tracing::debug!(
119 transformed_query = transformed_query.current(),
120 query_transformer = transformer.name(),
121 "Transformed query"
122 );
123
124 Ok(transformed_query)
125 }
126 .instrument(span.or_current()),
127 )
128 .err_into::<anyhow::Error>()
129 })
130 .try_buffer_unordered(default_concurrency)
131 .map(|x| x.and_then(|x| x));
132
133 Pipeline {
134 stream: new_stream.boxed().into(),
135 search_strategy,
136 query_sender,
137 evaluator,
138 default_concurrency,
139 }
140 }
141}
142
143impl<'stream: 'static, STRATEGY: SearchStrategy + 'stream>
144 Pipeline<'stream, STRATEGY, states::Pending>
145{
146 #[must_use]
148 pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<STRATEGY> + 'stream>>(
149 self,
150 retriever: T,
151 ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
152 let retriever = Arc::new(retriever.to_owned());
153 let Pipeline {
154 stream,
155 query_sender,
156 search_strategy,
157 evaluator,
158 default_concurrency,
159 } = self;
160
161 let strategy_for_stream = search_strategy.clone();
162 let evaluator_for_stream = evaluator.clone();
163
164 let new_stream = stream
165 .map_ok(move |query| {
166 let search_strategy = strategy_for_stream.clone();
167 let retriever = Arc::clone(&retriever);
168 let span = tracing::info_span!("then_retrieve", query = ?query);
169 let evaluator_for_stream = evaluator_for_stream.clone();
170
171 tokio::spawn(
172 async move {
173 let result = retriever.retrieve(&search_strategy, query).await?;
174
175 tracing::debug!(
176 num_documents = result.documents().len(),
177 total_bytes = result
178 .documents()
179 .iter()
180 .map(|d| d.bytes().len())
181 .sum::<usize>(),
182 "Retrieved documents"
183 );
184
185 if let Some(evaluator) = evaluator_for_stream.as_ref() {
186 evaluator.evaluate(result.clone().into()).await?;
187 Ok(result)
188 } else {
189 Ok(result)
190 }
191 }
192 .instrument(span.or_current()),
193 )
194 .err_into::<anyhow::Error>()
195 })
196 .try_buffer_unordered(default_concurrency)
197 .map(|x| x.and_then(|x| x));
198
199 Pipeline {
200 stream: new_stream.boxed().into(),
201 search_strategy: search_strategy.clone(),
202 query_sender,
203 evaluator,
204 default_concurrency,
205 }
206 }
207}
208
209impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
210 #[must_use]
212 pub fn then_transform_response<T: TransformResponse + 'stream>(
213 self,
214 transformer: T,
215 ) -> Pipeline<'stream, STRATEGY, states::Retrieved> {
216 let transformer = Arc::new(transformer);
217 let Pipeline {
218 stream,
219 query_sender,
220 search_strategy,
221 evaluator,
222 default_concurrency,
223 } = self;
224
225 let new_stream = stream
226 .map_ok(move |query| {
227 let transformer = Arc::clone(&transformer);
228 let span = tracing::info_span!("then_transform_response", query = ?query);
229 tokio::spawn(
230 async move {
231 let transformed_query = transformer.transform_response(query).await?;
232 tracing::debug!(
233 transformed_query = transformed_query.current(),
234 response_transformer = transformer.name(),
235 "Transformed response"
236 );
237
238 Ok(transformed_query)
239 }
240 .instrument(span.or_current()),
241 )
242 .err_into::<anyhow::Error>()
243 })
244 .try_buffer_unordered(default_concurrency)
245 .map(|x| x.and_then(|x| x));
246
247 Pipeline {
248 stream: new_stream.boxed().into(),
249 search_strategy,
250 query_sender,
251 evaluator,
252 default_concurrency,
253 }
254 }
255}
256
257impl<'stream: 'static, STRATEGY: SearchStrategy> Pipeline<'stream, STRATEGY, states::Retrieved> {
258 #[must_use]
260 pub fn then_answer<T: Answer + 'stream>(
261 self,
262 answerer: T,
263 ) -> Pipeline<'stream, STRATEGY, states::Answered> {
264 let answerer = Arc::new(answerer);
265 let Pipeline {
266 stream,
267 query_sender,
268 search_strategy,
269 evaluator,
270 default_concurrency,
271 } = self;
272 let evaluator_for_stream = evaluator.clone();
273
274 let new_stream = stream
275 .map_ok(move |query: Query<states::Retrieved>| {
276 let answerer = Arc::clone(&answerer);
277 let span = tracing::info_span!("then_answer", query = ?query);
278 let evaluator_for_stream = evaluator_for_stream.clone();
279
280 tokio::spawn(
281 async move {
282 tracing::debug!(answerer = answerer.name(), "Answering query");
283 let result = answerer.answer(query).await?;
284
285 if let Some(evaluator) = evaluator_for_stream.as_ref() {
286 evaluator.evaluate(result.clone().into()).await?;
287 Ok(result)
288 } else {
289 Ok(result)
290 }
291 }
292 .instrument(span.or_current()),
293 )
294 .err_into::<anyhow::Error>()
295 })
296 .try_buffer_unordered(default_concurrency)
297 .map(|x| x.and_then(|x| x));
298
299 Pipeline {
300 stream: new_stream.boxed().into(),
301 search_strategy,
302 query_sender,
303 evaluator,
304 default_concurrency,
305 }
306 }
307}
308
309impl<STRATEGY: SearchStrategy> Pipeline<'_, STRATEGY, states::Answered> {
310 #[tracing::instrument(skip_all, name = "query_pipeline.query")]
316 pub async fn query(
317 mut self,
318 query: impl Into<Query<states::Pending>>,
319 ) -> Result<Query<states::Answered>> {
320 tracing::debug!("Sending query");
321 let now = std::time::Instant::now();
322
323 self.query_sender.send(Ok(query.into())).await?;
324
325 let answer = self.stream.try_next().await?.ok_or_else(|| {
326 anyhow::anyhow!("Pipeline did not receive a response from the query stream")
327 });
328
329 let elapsed_in_seconds = now.elapsed().as_secs();
330 tracing::warn!(
331 elapsed_in_seconds,
332 "Answered query in {} seconds",
333 elapsed_in_seconds
334 );
335
336 answer
337 }
338
339 #[tracing::instrument(skip_all, name = "query_pipeline.query_mut")]
348 pub async fn query_mut(
349 &mut self,
350 query: impl Into<Query<states::Pending>>,
351 ) -> Result<Query<states::Answered>> {
352 tracing::warn!("Sending query");
353 let now = std::time::Instant::now();
354
355 self.query_sender.send(Ok(query.into())).await?;
356
357 let answer = self
358 .stream
359 .by_ref()
360 .take(1)
361 .try_next()
362 .await?
363 .ok_or_else(|| {
364 anyhow::anyhow!("Pipeline did not receive a response from the query stream")
365 });
366
367 tracing::debug!(?answer, "Received an answer");
368
369 let elapsed_in_seconds = now.elapsed().as_secs();
370 tracing::warn!(
371 elapsed_in_seconds,
372 "Answered query in {} seconds",
373 elapsed_in_seconds
374 );
375
376 answer
377 }
378
379 #[tracing::instrument(skip_all, name = "query_pipeline.query_all")]
386 pub async fn query_all(
387 self,
388 queries: Vec<impl Into<Query<states::Pending>> + Clone>,
389 ) -> Result<Vec<Query<states::Answered>>> {
390 tracing::warn!("Sending queries");
391 let now = std::time::Instant::now();
392
393 let Pipeline {
394 query_sender,
395 mut stream,
396 ..
397 } = self;
398
399 for query in &queries {
400 query_sender.send(Ok(query.clone().into())).await?;
401 }
402 tracing::info!("All queries sent");
403
404 let mut results = vec![];
405 while let Some(result) = stream.try_next().await? {
406 tracing::debug!(?result, "Received an answer");
407 results.push(result);
408 if results.len() == queries.len() {
409 break;
410 }
411 }
412
413 let elapsed_in_seconds = now.elapsed().as_secs();
414 tracing::warn!(
415 num_queries = queries.len(),
416 elapsed_in_seconds,
417 "Answered all queries in {} seconds",
418 elapsed_in_seconds
419 );
420 Ok(results)
421 }
422}
423
#[cfg(test)]
mod test {
    use swiftide_core::{
        MockAnswer, MockTransformQuery, MockTransformResponse, querying::search_strategies,
    };

    use super::*;

    /// Builds a pipeline whose steps are pass-through closures and whose answer is "Ok".
    fn closure_pipeline()
    -> Pipeline<'static, search_strategies::SimilaritySingleEmbedding, states::Answered> {
        Pipeline::default()
            .then_transform_query(move |query: Query<states::Pending>| Ok(query))
            .then_retrieve(
                move |_: &search_strategies::SimilaritySingleEmbedding,
                      query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Ok)
            .then_answer(move |query: Query<states::Retrieved>| Ok(query.answered("Ok")))
    }

    #[tokio::test]
    async fn test_closures_in_each_step() {
        let response = closure_pipeline().query("What").await.unwrap();
        assert_eq!(response.answer(), "Ok");
    }

    #[tokio::test]
    async fn test_all_steps_should_accept_dyn_box() {
        let mut query_transformer = MockTransformQuery::new();
        query_transformer.expect_transform_query().returning(Ok);

        let mut response_transformer = MockTransformResponse::new();
        response_transformer
            .expect_transform_response()
            .returning(Ok);
        let mut answer_transformer = MockAnswer::new();
        answer_transformer
            .expect_answer()
            .returning(|query| Ok(query.answered("OK")));

        let pipeline = Pipeline::default()
            .then_transform_query(Box::new(query_transformer) as Box<dyn TransformQuery>)
            .then_retrieve(
                |_: &search_strategies::SimilaritySingleEmbedding,
                 query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            )
            .then_transform_response(Box::new(response_transformer) as Box<dyn TransformResponse>)
            .then_answer(Box::new(answer_transformer) as Box<dyn Answer>);
        let response = pipeline.query("What").await.unwrap();
        assert_eq!(response.answer(), "OK");
    }

    #[tokio::test]
    async fn test_reuse_with_query_mut() {
        let mut pipeline = closure_pipeline();

        for _ in 0..2 {
            let response = pipeline.query_mut("What").await.unwrap();
            assert_eq!(response.answer(), "Ok");
        }
    }

    #[tokio::test]
    async fn test_query_all_answers_every_query() {
        let responses = closure_pipeline()
            .query_all(vec!["What", "Why", "How"])
            .await
            .unwrap();

        assert_eq!(responses.len(), 3);
        for response in &responses {
            assert_eq!(response.answer(), "Ok");
        }
    }
}