swiftide_core/search_strategies/
custom_strategy.rs

//! Implements a flexible vector search strategy framework using closure-based configuration.
//! Supports both synchronous and asynchronous query generation for different retrieval backends.
3
4use crate::querying::{self, Query, states};
5use anyhow::{Result, anyhow};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10
// TODO: It should be possible to remove the `'static` bounds and allow `Q` to be a borrowed type with some further lifetime work.
12
13// Function type for generating retriever-specific queries
14type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;
15
16// Function type for async query generation
17type AsyncQueryGenerator<Q> = Arc<
18    dyn Fn(&Query<states::Pending>) -> Pin<Box<dyn Future<Output = Result<Q>> + Send>>
19        + Send
20        + Sync,
21>;
22
23/// Implements the strategy pattern for vector similarity search, allowing retrieval backends
24/// to define custom query generation logic through closures.
25pub struct CustomStrategy<Q> {
26    query: Option<QueryGenerator<Q>>,
27    async_query: Option<AsyncQueryGenerator<Q>>,
28    _marker: PhantomData<Q>,
29}
30
31impl<Q: Send + Sync> querying::SearchStrategy for CustomStrategy<Q> {}
32
33impl<Q> Default for CustomStrategy<Q> {
34    fn default() -> Self {
35        Self {
36            query: None,
37            async_query: None,
38            _marker: PhantomData,
39        }
40    }
41}
42
43impl<Q> Clone for CustomStrategy<Q> {
44    fn clone(&self) -> Self {
45        Self {
46            query: self.query.clone(),
47            async_query: self.async_query.clone(),
48            _marker: PhantomData,
49        }
50    }
51}
52
53impl<Q: Send + Sync> CustomStrategy<Q> {
54    /// Creates a new strategy with a synchronous query generator.
55    pub fn from_query(
56        query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
57    ) -> Self {
58        Self {
59            query: Some(Arc::new(query)),
60            async_query: None,
61            _marker: PhantomData,
62        }
63    }
64
65    /// Creates a new strategy with an asynchronous query generator.
66    pub fn from_async_query<F>(
67        query: impl Fn(&Query<states::Pending>) -> F + Send + Sync + 'static,
68    ) -> Self
69    where
70        F: Future<Output = Result<Q>> + Send + 'static,
71    {
72        Self {
73            query: None,
74            async_query: Some(Arc::new(move |q| Box::pin(query(q)))),
75            _marker: PhantomData,
76        }
77    }
78
79    /// Generates a query using either the sync or async generator.
80    /// Returns error if no query generator is set.
81    ///
82    /// # Errors
83    /// Returns an error if:
84    /// * No query generator has been configured
85    /// * The configured query generator fails during query generation
86    pub async fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
87        match (&self.query, &self.async_query) {
88            (Some(query_fn), _) => query_fn(query_node),
89            (_, Some(async_fn)) => async_fn(query_node).await,
90            _ => Err(anyhow!("No query function has been set.")),
91        }
92    }
93}