swiftide_core/search_strategies/
custom_strategy.rs

1//! Implements a flexible vector search strategy framework using closure-based configuration.
2//! Supports both synchronous and asynchronous query generation for different retrieval backends.
3
4use crate::querying::{self, states, Query};
5use anyhow::{anyhow, Result};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10
11// Function type for generating retriever-specific queries
12type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;
13
14// Function type for async query generation
15type AsyncQueryGenerator<Q> = Arc<
16    dyn Fn(&Query<states::Pending>) -> Pin<Box<dyn Future<Output = Result<Q>> + Send>>
17        + Send
18        + Sync,
19>;
20
21/// Implements the strategy pattern for vector similarity search, allowing retrieval backends
22/// to define custom query generation logic through closures.
23pub struct CustomStrategy<Q> {
24    query: Option<QueryGenerator<Q>>,
25    async_query: Option<AsyncQueryGenerator<Q>>,
26    _marker: PhantomData<Q>,
27}
28
29impl<Q: Send + Sync + 'static> querying::SearchStrategy for CustomStrategy<Q> {}
30
31impl<Q> Default for CustomStrategy<Q> {
32    fn default() -> Self {
33        Self {
34            query: None,
35            async_query: None,
36            _marker: PhantomData,
37        }
38    }
39}
40
41impl<Q> Clone for CustomStrategy<Q> {
42    fn clone(&self) -> Self {
43        Self {
44            query: self.query.clone(),
45            async_query: self.async_query.clone(),
46            _marker: PhantomData,
47        }
48    }
49}
50
51impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
52    /// Creates a new strategy with a synchronous query generator.
53    pub fn from_query(
54        query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
55    ) -> Self {
56        Self {
57            query: Some(Arc::new(query)),
58            async_query: None,
59            _marker: PhantomData,
60        }
61    }
62
63    /// Creates a new strategy with an asynchronous query generator.
64    pub fn from_async_query<F>(
65        query: impl Fn(&Query<states::Pending>) -> F + Send + Sync + 'static,
66    ) -> Self
67    where
68        F: Future<Output = Result<Q>> + Send + 'static,
69    {
70        Self {
71            query: None,
72            async_query: Some(Arc::new(move |q| Box::pin(query(q)))),
73            _marker: PhantomData,
74        }
75    }
76
77    /// Generates a query using either the sync or async generator.
78    /// Returns error if no query generator is set.
79    ///
80    /// # Errors
81    /// Returns an error if:
82    /// * No query generator has been configured
83    /// * The configured query generator fails during query generation
84    pub async fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
85        match (&self.query, &self.async_query) {
86            (Some(query_fn), _) => query_fn(query_node),
87            (_, Some(async_fn)) => async_fn(query_node).await,
88            _ => Err(anyhow!("No query function has been set.")),
89        }
90    }
91}