swiftide_core/search_strategies/
custom_strategy.rs1use crate::querying::{self, states, Query};
5use anyhow::{anyhow, Result};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10
11type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;
15
16type AsyncQueryGenerator<Q> = Arc<
18    dyn Fn(&Query<states::Pending>) -> Pin<Box<dyn Future<Output = Result<Q>> + Send>>
19        + Send
20        + Sync,
21>;
22
23pub struct CustomStrategy<Q> {
26    query: Option<QueryGenerator<Q>>,
27    async_query: Option<AsyncQueryGenerator<Q>>,
28    _marker: PhantomData<Q>,
29}
30
31impl<Q: Send + Sync> querying::SearchStrategy for CustomStrategy<Q> {}
32
33impl<Q> Default for CustomStrategy<Q> {
34    fn default() -> Self {
35        Self {
36            query: None,
37            async_query: None,
38            _marker: PhantomData,
39        }
40    }
41}
42
43impl<Q> Clone for CustomStrategy<Q> {
44    fn clone(&self) -> Self {
45        Self {
46            query: self.query.clone(),
47            async_query: self.async_query.clone(),
48            _marker: PhantomData,
49        }
50    }
51}
52
53impl<Q: Send + Sync> CustomStrategy<Q> {
54    pub fn from_query(
56        query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
57    ) -> Self {
58        Self {
59            query: Some(Arc::new(query)),
60            async_query: None,
61            _marker: PhantomData,
62        }
63    }
64
65    pub fn from_async_query<F>(
67        query: impl Fn(&Query<states::Pending>) -> F + Send + Sync + 'static,
68    ) -> Self
69    where
70        F: Future<Output = Result<Q>> + Send + 'static,
71    {
72        Self {
73            query: None,
74            async_query: Some(Arc::new(move |q| Box::pin(query(q)))),
75            _marker: PhantomData,
76        }
77    }
78
79    pub async fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
87        match (&self.query, &self.async_query) {
88            (Some(query_fn), _) => query_fn(query_node),
89            (_, Some(async_fn)) => async_fn(query_node).await,
90            _ => Err(anyhow!("No query function has been set.")),
91        }
92    }
93}