Skip to main content

wesichain_graph/
retriever_node.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use futures::stream::StreamExt;
5use wesichain_core::{
6    Document, Embedding, EmbeddingError, HasMetadataFilter, HasQuery, HasRetrievedDocs,
7    MetadataFilter, Runnable, SearchResult, StoreError, StreamEvent, VectorStore, WesichainError,
8};
9use wesichain_retrieval::Retriever;
10
11use crate::{GraphState, StateSchema, StateUpdate};
12
13/// Newtype wrapper to implement `Embedding` for `Arc<dyn Embedding>`.
14///
15/// Required by Rust's orphan rule: we can't implement a foreign trait (`Embedding`)
16/// on a foreign type (`Arc<dyn Embedding>`) directly. This wrapper delegates all
17/// methods and is used internally by [`RetrieverNode`].
18struct DynEmbedding(Arc<dyn Embedding>);
19
20#[async_trait]
21impl Embedding for DynEmbedding {
22    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
23        self.0.embed(text).await
24    }
25
26    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
27        self.0.embed_batch(texts).await
28    }
29
30    fn dimension(&self) -> usize {
31        self.0.dimension()
32    }
33}
34
35/// Newtype wrapper to implement `VectorStore` for `Arc<dyn VectorStore>`.
36///
37/// Same orphan-rule workaround as [`DynEmbedding`]. Used internally by
38/// [`RetrieverNode`] to bridge the `Arc<dyn VectorStore>` constructor parameter
39/// into the concrete `Retriever<E, V>` type.
40struct DynVectorStore(Arc<dyn VectorStore>);
41
42#[async_trait]
43impl VectorStore for DynVectorStore {
44    async fn add(&self, docs: Vec<Document>) -> Result<(), StoreError> {
45        self.0.add(docs).await
46    }
47
48    async fn search(
49        &self,
50        query_embedding: &[f32],
51        top_k: usize,
52        filter: Option<&MetadataFilter>,
53    ) -> Result<Vec<SearchResult>, StoreError> {
54        self.0.search(query_embedding, top_k, filter).await
55    }
56
57    async fn delete(&self, ids: &[String]) -> Result<(), StoreError> {
58        self.0.delete(ids).await
59    }
60}
61
62pub struct RetrieverNode {
63    retriever: Retriever<DynEmbedding, DynVectorStore>,
64    top_k: usize,
65    score_threshold: Option<f32>,
66}
67
68impl RetrieverNode {
69    pub fn new(
70        embedder: Arc<dyn Embedding>,
71        store: Arc<dyn VectorStore>,
72        top_k: usize,
73        score_threshold: Option<f32>,
74    ) -> Self {
75        Self {
76            retriever: Retriever::new(DynEmbedding(embedder), DynVectorStore(store)),
77            top_k,
78            score_threshold,
79        }
80    }
81
82    fn apply_threshold(&self, mut results: Vec<SearchResult>) -> Vec<SearchResult> {
83        if let Some(threshold) = self.score_threshold {
84            results.retain(|res| res.score >= threshold);
85        }
86        results
87    }
88}
89
90#[async_trait]
91impl<S> Runnable<GraphState<S>, StateUpdate<S>> for RetrieverNode
92where
93    S: StateSchema<Update = S> + HasQuery + HasRetrievedDocs + HasMetadataFilter,
94{
95    async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
96        let query = input.data.query();
97        let filter = input.data.metadata_filter();
98        let results = self
99            .retriever
100            .retrieve(query, self.top_k, filter.as_ref())
101            .await
102            .map_err(|err| WesichainError::Custom(err.to_string()))?;
103        let results = self.apply_threshold(results);
104        let docs = results.into_iter().map(|res| res.document).collect();
105
106        let mut state = input;
107        state.data.set_retrieved_docs(docs);
108        Ok(StateUpdate::new(state.data))
109    }
110
111    fn stream(
112        &self,
113        _input: GraphState<S>,
114    ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
115        futures::stream::empty().boxed()
116    }
117}