wesichain_graph/
retriever_node.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use futures::stream::StreamExt;
5use wesichain_core::{
6 Document, Embedding, EmbeddingError, HasMetadataFilter, HasQuery, HasRetrievedDocs,
7 MetadataFilter, Runnable, SearchResult, StoreError, StreamEvent, VectorStore, WesichainError,
8};
9use wesichain_retrieval::Retriever;
10
11use crate::{GraphState, StateSchema, StateUpdate};
12
13struct DynEmbedding(Arc<dyn Embedding>);
19
20#[async_trait]
21impl Embedding for DynEmbedding {
22 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
23 self.0.embed(text).await
24 }
25
26 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
27 self.0.embed_batch(texts).await
28 }
29
30 fn dimension(&self) -> usize {
31 self.0.dimension()
32 }
33}
34
35struct DynVectorStore(Arc<dyn VectorStore>);
41
42#[async_trait]
43impl VectorStore for DynVectorStore {
44 async fn add(&self, docs: Vec<Document>) -> Result<(), StoreError> {
45 self.0.add(docs).await
46 }
47
48 async fn search(
49 &self,
50 query_embedding: &[f32],
51 top_k: usize,
52 filter: Option<&MetadataFilter>,
53 ) -> Result<Vec<SearchResult>, StoreError> {
54 self.0.search(query_embedding, top_k, filter).await
55 }
56
57 async fn delete(&self, ids: &[String]) -> Result<(), StoreError> {
58 self.0.delete(ids).await
59 }
60}
61
62pub struct RetrieverNode {
63 retriever: Retriever<DynEmbedding, DynVectorStore>,
64 top_k: usize,
65 score_threshold: Option<f32>,
66}
67
68impl RetrieverNode {
69 pub fn new(
70 embedder: Arc<dyn Embedding>,
71 store: Arc<dyn VectorStore>,
72 top_k: usize,
73 score_threshold: Option<f32>,
74 ) -> Self {
75 Self {
76 retriever: Retriever::new(DynEmbedding(embedder), DynVectorStore(store)),
77 top_k,
78 score_threshold,
79 }
80 }
81
82 fn apply_threshold(&self, mut results: Vec<SearchResult>) -> Vec<SearchResult> {
83 if let Some(threshold) = self.score_threshold {
84 results.retain(|res| res.score >= threshold);
85 }
86 results
87 }
88}
89
90#[async_trait]
91impl<S> Runnable<GraphState<S>, StateUpdate<S>> for RetrieverNode
92where
93 S: StateSchema<Update = S> + HasQuery + HasRetrievedDocs + HasMetadataFilter,
94{
95 async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
96 let query = input.data.query();
97 let filter = input.data.metadata_filter();
98 let results = self
99 .retriever
100 .retrieve(query, self.top_k, filter.as_ref())
101 .await
102 .map_err(|err| WesichainError::Custom(err.to_string()))?;
103 let results = self.apply_threshold(results);
104 let docs = results.into_iter().map(|res| res.document).collect();
105
106 let mut state = input;
107 state.data.set_retrieved_docs(docs);
108 Ok(StateUpdate::new(state.data))
109 }
110
111 fn stream(
112 &self,
113 _input: GraphState<S>,
114 ) -> futures::stream::BoxStream<'_, Result<StreamEvent, WesichainError>> {
115 futures::stream::empty().boxed()
116 }
117}