1use async_trait::async_trait;
13
14use crate::{Document, GroundingStrategy, KeywordRetriever, Retriever};
15
/// Retriever that merges results from a keyword source and a second
/// (presumably embedding-based) source by summing their max-normalized
/// scores under configurable weights — see the `Retriever` impl below.
pub struct HybridRetriever<E: Retriever> {
    // Lexical/keyword search source.
    keyword: KeywordRetriever,
    // Second retrieval source; any `Retriever` impl (named for embeddings,
    // but nothing here requires that).
    embedding: E,
    // Weight applied to normalized keyword scores (defaults to 0.3 in `new`).
    keyword_weight: f64,
    // Weight applied to normalized embedding scores (defaults to 0.7 in `new`).
    embedding_weight: f64,
}
47
48impl<E: Retriever> HybridRetriever<E> {
49 #[must_use]
51 pub fn new(keyword: KeywordRetriever, embedding: E) -> Self {
52 Self {
53 keyword,
54 embedding,
55 keyword_weight: 0.3,
56 embedding_weight: 0.7,
57 }
58 }
59
60 #[must_use]
68 pub fn with_weights(mut self, keyword_weight: f64, embedding_weight: f64) -> Self {
69 assert!(keyword_weight >= 0.0, "keyword_weight must be non-negative");
70 assert!(
71 embedding_weight >= 0.0,
72 "embedding_weight must be non-negative"
73 );
74 self.keyword_weight = keyword_weight;
75 self.embedding_weight = embedding_weight;
76 self
77 }
78}
79
80#[async_trait]
81impl<E: Retriever> Retriever for HybridRetriever<E> {
82 async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
84 let candidate = (limit * 5).max(10);
86
87 let (kw_docs, emb_docs) = tokio::join!(
88 self.keyword.retrieve(query, candidate),
89 self.embedding.retrieve(query, candidate),
90 );
91 let kw_docs = kw_docs.unwrap_or_default();
92 let emb_docs = emb_docs.unwrap_or_default();
93
94 let kw_max = kw_docs
96 .iter()
97 .map(|d| d.score)
98 .fold(f64::NEG_INFINITY, f64::max);
99 let emb_max = emb_docs
100 .iter()
101 .map(|d| d.score)
102 .fold(f64::NEG_INFINITY, f64::max);
103
104 let mut scores: std::collections::HashMap<String, (f64, &Document)> =
106 std::collections::HashMap::new();
107
108 for doc in &kw_docs {
109 let norm = if kw_max > 0.0 {
110 doc.score / kw_max
111 } else {
112 0.0
113 };
114 scores
115 .entry(doc.id.clone())
116 .and_modify(|(s, _)| *s += self.keyword_weight * norm)
117 .or_insert((self.keyword_weight * norm, doc));
118 }
119
120 for doc in &emb_docs {
121 let norm = if emb_max > 0.0 {
122 doc.score / emb_max
123 } else {
124 0.0
125 };
126 scores
127 .entry(doc.id.clone())
128 .and_modify(|(s, _)| *s += self.embedding_weight * norm)
129 .or_insert((self.embedding_weight * norm, doc));
130 }
131
132 let mut combined: Vec<Document> = scores
133 .into_values()
134 .map(|(combined_score, doc)| {
135 let mut d = doc.clone();
136 d.score = combined_score;
137 d
138 })
139 .collect();
140
141 combined.sort_by(|a, b| {
142 b.score
143 .partial_cmp(&a.score)
144 .unwrap_or(std::cmp::Ordering::Equal)
145 });
146 combined.truncate(limit);
147
148 Ok(combined)
149 }
150}
151
/// Grounding strategy that renders every document as a numbered entry with a
/// trailing `(Source: <id>)` citation.
pub struct CitationStrategy;
172
173impl GroundingStrategy for CitationStrategy {
174 fn ground(&self, documents: &[Document]) -> String {
175 if documents.is_empty() {
176 return String::new();
177 }
178 let mut ctx = String::from("Context:\n\n");
179 for (i, doc) in documents.iter().enumerate() {
180 use std::fmt::Write;
181 let _ = writeln!(ctx, "[{}] {} (Source: {})", i + 1, doc.content, doc.id);
182 }
183 ctx
184 }
185}
186
/// Grounding strategy that caps the formatted context at an approximate
/// token budget (estimated in `ground` at 4 characters per token).
pub struct ContextWindowStrategy {
    // Maximum token budget; converted to a character budget when grounding.
    max_tokens: usize,
}
211
212impl ContextWindowStrategy {
213 #[must_use]
219 pub fn new(max_tokens: usize) -> Self {
220 assert!(max_tokens > 0, "max_tokens must be > 0");
221 Self { max_tokens }
222 }
223}
224
225impl GroundingStrategy for ContextWindowStrategy {
226 fn ground(&self, documents: &[Document]) -> String {
227 if documents.is_empty() {
228 return String::new();
229 }
230
231 let char_budget = self.max_tokens * 4;
233 let mut ctx = String::from("Context:\n\n");
234 let mut used = ctx.len();
235
236 for (i, doc) in documents.iter().enumerate() {
237 use std::fmt::Write;
238 let entry = format!("[{}] {}\n\n", i + 1, doc.content);
239
240 if used + entry.len() > char_budget {
241 let available = char_budget.saturating_sub(used + 10); if available > 20 {
244 let truncated: String = doc.content.chars().take(available).collect();
245 let _ = write!(ctx, "[{}] {}…\n\n", i + 1, truncated);
246 }
247 break;
248 }
249
250 ctx.push_str(&entry);
251 used += entry.len();
252 }
253
254 ctx
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Builds a `KeywordRetriever` pre-populated with the given (id, content) docs.
    fn kw_retriever_with(docs: Vec<(&str, &str)>) -> KeywordRetriever {
        let mut r = KeywordRetriever::new();
        for (id, content) in docs {
            r.add(Document::new(id, content));
        }
        r
    }

    #[tokio::test]
    async fn test_hybrid_returns_from_keyword_source() {
        let kw = kw_retriever_with(vec![("k1", "Rust programming"), ("k2", "Python code")]);
        // Empty embedding side: any results must come from the keyword source.
        let emb = KeywordRetriever::new();
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 5).await.unwrap();

        assert!(!results.is_empty(), "expected keyword results");
        assert!(results.iter().any(|d| d.id == "k1"));
    }

    #[tokio::test]
    async fn test_hybrid_merges_both_sources() {
        let kw = kw_retriever_with(vec![("k1", "Rust keyword hit")]);
        let emb = kw_retriever_with(vec![("e1", "Rust embedding hit")]);

        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust hit", 10).await.unwrap();

        let ids: Vec<_> = results.iter().map(|d| d.id.as_str()).collect();
        // Both sources match the query and the limit is generous, so a
        // correct merge contains a document from each. (The previous `||`
        // check passed even when one source was dropped entirely, which
        // contradicted the test's name and message.)
        assert!(
            ids.contains(&"k1") && ids.contains(&"e1"),
            "should contain results from both: {ids:?}"
        );
    }

    #[tokio::test]
    async fn test_hybrid_respects_limit() {
        let kw = kw_retriever_with(vec![("k1", "Rust a"), ("k2", "Rust b"), ("k3", "Rust c")]);
        let emb = kw_retriever_with(vec![("e1", "Rust d"), ("e2", "Rust e")]);
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 2).await.unwrap();
        assert!(results.len() <= 2);
    }

    #[tokio::test]
    async fn test_hybrid_combined_score_sorted_desc() {
        let kw = kw_retriever_with(vec![("k1", "Rust programming")]);
        let emb = kw_retriever_with(vec![("e1", "Rust embedding search")]);
        let hybrid = HybridRetriever::new(kw, emb);
        let results = hybrid.retrieve("Rust", 10).await.unwrap();

        // Scores must be non-increasing across adjacent pairs.
        for window in results.windows(2) {
            assert!(window[0].score >= window[1].score);
        }
    }

    #[test]
    fn test_citation_strategy_format() {
        let docs = vec![
            Document::new("paper-42", "Important finding."),
            Document::new("blog-7", "Another insight."),
        ];
        let ctx = CitationStrategy.ground(&docs);
        assert!(ctx.contains("[1]"));
        assert!(ctx.contains("Source: paper-42"));
        assert!(ctx.contains("[2]"));
        assert!(ctx.contains("Source: blog-7"));
    }

    #[test]
    fn test_citation_strategy_empty() {
        assert!(CitationStrategy.ground(&[]).is_empty());
    }

    #[test]
    fn test_context_window_small_docs_fit() {
        let docs = vec![
            Document::new("d1", "Short."),
            Document::new("d2", "Also short."),
        ];
        let ctx = ContextWindowStrategy::new(1000).ground(&docs);
        assert!(ctx.contains("[1]"));
        assert!(ctx.contains("[2]"));
    }

    #[test]
    fn test_context_window_truncates_large_doc() {
        // ~2500 chars of content against a 50-token (~200 char) budget.
        let large = "word ".repeat(500);
        let docs = vec![Document::new("big", &large)];

        let strategy = ContextWindowStrategy::new(50);
        let ctx = strategy.ground(&docs);

        assert!(
            ctx.chars().count() < 500,
            "expected truncation, got {} chars",
            ctx.chars().count()
        );
    }

    #[test]
    fn test_context_window_empty_docs() {
        assert!(ContextWindowStrategy::new(100).ground(&[]).is_empty());
    }
}