use std::sync::Arc;
use swiftide_core::{
prelude::*,
querying::{
search_strategies::SimilaritySingleEmbedding, states, Answer, Query, QueryStream, Retrieve,
SearchStrategy, TransformQuery, TransformResponse,
},
EvaluateQuery,
};
use tokio::sync::mpsc::Sender;
/// A query pipeline that pushes queries through a chain of asynchronous steps.
///
/// `S` is the search strategy handed to retrievers and `T` is the state marker of
/// the queries currently produced by `stream`.
pub struct Pipeline<'stream, S = SimilaritySingleEmbedding, T = states::Pending>
where
    S: SearchStrategy,
{
    /// Strategy passed (by reference) to the retriever added with `then_retrieve`.
    search_strategy: S,
    /// Stream of query results produced by the steps added so far.
    stream: QueryStream<'stream, T>,
    /// Entry point used to push pending queries into the head of the pipeline.
    query_sender: Sender<Result<Query<states::Pending>>>,
    /// Optional evaluator, invoked after retrieval and after answering.
    evaluator: Option<Arc<Box<dyn EvaluateQuery>>>,
    /// Upper bound on how many queries each step processes concurrently.
    default_concurrency: usize,
}
impl Default for Pipeline<'_, SimilaritySingleEmbedding> {
    /// Builds an empty pipeline that uses the default `SimilaritySingleEmbedding`
    /// search strategy.
    ///
    /// This delegates to `Pipeline::from_search_strategy` so that stream creation,
    /// sender extraction and concurrency defaults live in exactly one place.
    ///
    /// # Panics
    ///
    /// Panics (via `from_search_strategy`) if the freshly created `QueryStream`
    /// exposes no sender.
    fn default() -> Self {
        Self::from_search_strategy(SimilaritySingleEmbedding::default())
    }
}
impl<'a, S: SearchStrategy> Pipeline<'a, S> {
    /// Creates an empty pipeline whose retrieval step will use `strategy`.
    ///
    /// No steps and no evaluator are configured. Per-step concurrency is set to
    /// `num_cpus::get()`.
    ///
    /// # Panics
    ///
    /// Panics if the newly created `QueryStream` has no sender to use as the
    /// pipeline's query entry point.
    #[must_use]
    pub fn from_search_strategy(strategy: S) -> Pipeline<'a, S> {
        let stream = QueryStream::default();
        // The pipeline keeps its own handle to the stream's input channel so that
        // `query`/`query_all` can feed it later.
        let entrypoint = stream
            .sender
            .clone()
            .expect("Pipeline received stream without query entrypoint");
        let concurrency = num_cpus::get();

        Pipeline {
            stream,
            query_sender: entrypoint,
            search_strategy: strategy,
            evaluator: None,
            default_concurrency: concurrency,
        }
    }
}
impl<'stream: 'static, S: SearchStrategy> Pipeline<'stream, S, states::Pending> {
    /// Registers an evaluator for this pipeline.
    ///
    /// The evaluator is captured by `then_retrieve` and `then_answer`. Because this
    /// method is only available while the pipeline is in the `Pending` state, it
    /// must be called before `then_retrieve`.
    #[must_use]
    pub fn evaluate_with<T: ToOwned<Owned = impl EvaluateQuery + 'stream>>(
        self,
        evaluator: T,
    ) -> Self {
        let boxed: Box<dyn EvaluateQuery> = Box::new(evaluator.to_owned());
        Pipeline {
            evaluator: Some(Arc::new(boxed)),
            ..self
        }
    }

    /// Appends a step that rewrites every pending query with `transformer`.
    ///
    /// Up to `default_concurrency` transformations run at once, and results are
    /// yielded in completion order. Errors coming from earlier steps are forwarded
    /// without invoking the transformer.
    #[must_use]
    pub fn then_transform_query<T: ToOwned<Owned = impl TransformQuery + 'stream>>(
        self,
        transformer: T,
    ) -> Pipeline<'stream, S, states::Pending> {
        let shared_transformer = Arc::new(transformer.to_owned());
        let concurrency = self.default_concurrency;

        let transformed = self
            .stream
            .map_ok(move |pending| {
                let step = Arc::clone(&shared_transformer);
                let span = tracing::trace_span!("then_transform_query", query = ?pending);
                async move { step.transform_query(pending).await }.instrument(span)
            })
            .try_buffer_unordered(concurrency);

        // Only the stream changes; every other field carries over unchanged.
        Pipeline {
            stream: transformed.boxed().into(),
            ..self
        }
    }
}
impl<'stream: 'static, S: SearchStrategy + 'stream> Pipeline<'stream, S, states::Pending> {
    /// Appends the retrieval step, moving the pipeline into the `Retrieved` state.
    ///
    /// Each pending query is handed to `retriever` together with the pipeline's
    /// search strategy. If an evaluator was registered with `evaluate_with`, it is
    /// called with a clone of every retrieved query. An evaluator error fails that
    /// query. Up to `default_concurrency` retrievals run concurrently, and results
    /// are yielded in completion order.
    #[must_use]
    pub fn then_retrieve<T: ToOwned<Owned = impl Retrieve<S> + 'stream>>(
        self,
        retriever: T,
    ) -> Pipeline<'stream, S, states::Retrieved> {
        let retriever = Arc::new(retriever.to_owned());
        let Pipeline {
            stream,
            query_sender,
            search_strategy,
            evaluator,
            default_concurrency,
        } = self;

        // The stream closure needs its own copies. The originals are moved into
        // the returned pipeline below, so no further clone is needed there.
        let strategy_for_stream = search_strategy.clone();
        let evaluator_for_stream = evaluator.clone();

        let new_stream = stream
            .map_ok(move |query| {
                // Each in-flight future owns its own strategy/retriever/evaluator handles.
                let search_strategy = strategy_for_stream.clone();
                let retriever = Arc::clone(&retriever);
                let evaluator = evaluator_for_stream.clone();
                let span = tracing::trace_span!("then_retrieve", query = ?query);
                async move {
                    let result = retriever.retrieve(&search_strategy, query).await?;
                    if let Some(evaluator) = evaluator.as_ref() {
                        evaluator.evaluate(result.clone().into()).await?;
                    }
                    Ok(result)
                }
                .instrument(span)
            })
            .try_buffer_unordered(default_concurrency);

        Pipeline {
            stream: new_stream.boxed().into(),
            search_strategy,
            query_sender,
            evaluator,
            default_concurrency,
        }
    }
}
impl<'stream: 'static, S: SearchStrategy> Pipeline<'stream, S, states::Retrieved> {
    /// Appends a step that rewrites every retrieved query with `transformer`.
    ///
    /// The pipeline stays in the `Retrieved` state. Up to `default_concurrency`
    /// transformations run at once, and results are yielded in completion order.
    #[must_use]
    pub fn then_transform_response<T: ToOwned<Owned = impl TransformResponse + 'stream>>(
        self,
        transformer: T,
    ) -> Pipeline<'stream, S, states::Retrieved> {
        let shared_transformer = Arc::new(transformer.to_owned());
        let concurrency = self.default_concurrency;

        let transformed = self
            .stream
            .map_ok(move |retrieved| {
                let step = Arc::clone(&shared_transformer);
                let span = tracing::trace_span!("then_transform_response", query = ?retrieved);
                async move { step.transform_response(retrieved).await }.instrument(span)
            })
            .try_buffer_unordered(concurrency);

        // Same state in and out, so the remaining fields can be carried over as-is.
        Pipeline {
            stream: transformed.boxed().into(),
            ..self
        }
    }
}
impl<'stream: 'static, S: SearchStrategy> Pipeline<'stream, S, states::Retrieved> {
    /// Appends the answering step, moving the pipeline into the `Answered` state.
    ///
    /// Each retrieved query is passed to `answerer`. If an evaluator was
    /// registered, it is called with a clone of every answered query, and an
    /// evaluator error fails that query. Up to `default_concurrency` answers are
    /// produced concurrently, in completion order.
    #[must_use]
    pub fn then_answer<T: ToOwned<Owned = impl Answer + 'stream>>(
        self,
        answerer: T,
    ) -> Pipeline<'stream, S, states::Answered> {
        let Pipeline {
            stream,
            query_sender,
            search_strategy,
            evaluator,
            default_concurrency,
        } = self;
        let shared_answerer = Arc::new(answerer.to_owned());
        let stage_evaluator = evaluator.clone();

        let answered_stream = stream
            .map_ok(move |query: Query<states::Retrieved>| {
                let span = tracing::trace_span!("then_answer", query = ?query);
                let answerer = Arc::clone(&shared_answerer);
                let evaluator = stage_evaluator.clone();
                async move {
                    let answered = answerer.answer(query).await?;
                    if let Some(evaluator) = &evaluator {
                        evaluator.evaluate(answered.clone().into()).await?;
                    }
                    Ok(answered)
                }
                .instrument(span)
            })
            .try_buffer_unordered(default_concurrency);

        Pipeline {
            stream: answered_stream.boxed().into(),
            search_strategy,
            query_sender,
            evaluator,
            default_concurrency,
        }
    }
}
impl<S: SearchStrategy> Pipeline<'_, S, states::Answered> {
    /// Sends a single query through the pipeline and waits for its answer.
    ///
    /// # Errors
    ///
    /// Returns an error if the query cannot be sent into the pipeline, if any step
    /// fails for this query, or if the stream ends without yielding an answer.
    pub async fn query(
        mut self,
        query: impl Into<Query<states::Pending>>,
    ) -> Result<Query<states::Answered>> {
        self.query_sender.send(Ok(query.into())).await?;
        self.stream.try_next().await?.ok_or_else(|| {
            anyhow::anyhow!("Pipeline did not receive a response from the query stream")
        })
    }

    /// Sends every query through the pipeline and collects the answers.
    ///
    /// Answers are returned in completion order, which is not necessarily the order
    /// of `queries`. An empty `queries` returns an empty `Vec` immediately.
    ///
    /// # Errors
    ///
    /// Returns the first error encountered while sending a query or while any
    /// pipeline step processes one.
    pub async fn query_all(
        self,
        queries: Vec<impl Into<Query<states::Pending>> + Clone>,
    ) -> Result<Vec<Query<states::Answered>>> {
        let expected = queries.len();
        // `query_sender` lives for the whole call, and the stream keeps its own
        // sender as well, so the stream never signals end-of-input. Polling it with
        // nothing outstanding would wait forever, so bail out early.
        if expected == 0 {
            return Ok(Vec::new());
        }

        let Pipeline {
            query_sender,
            mut stream,
            ..
        } = self;

        // NOTE(review): every query is sent before any answer is read. If the
        // underlying channel is bounded with a capacity below `expected`, `send`
        // would block — confirm QueryStream's channel capacity.
        for query in queries {
            query_sender.send(Ok(query.into())).await?;
        }
        tracing::info!("All queries sent");

        let mut results = Vec::with_capacity(expected);
        while let Some(result) = stream.try_next().await? {
            tracing::debug!(?result, "Received an answer");
            results.push(result);
            if results.len() == expected {
                break;
            }
        }
        Ok(results)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use swiftide_core::querying::search_strategies;

    /// Every pipeline step accepts a plain closure, and the value produced by the
    /// final (answering) closure comes back out of `Pipeline::query`.
    #[tokio::test]
    async fn test_closures_in_each_step() {
        let retrieving = Pipeline::default()
            .then_transform_query(move |query: Query<states::Pending>| Ok(query))
            .then_retrieve(
                move |_: &search_strategies::SimilaritySingleEmbedding,
                      query: Query<states::Pending>| {
                    Ok(query.retrieved_documents(vec![]))
                },
            );
        let pipeline = retrieving
            .then_transform_response(Ok)
            .then_answer(move |query: Query<states::Retrieved>| Ok(query.answered("Ok")));

        let answered = pipeline.query("What").await.unwrap();
        assert_eq!(answered.answer(), "Ok");
    }
}