// traitclaw_rag/hybrid.rs
1//! Hybrid retrieval combining keyword and embedding-based search.
2//!
3//! [`HybridRetriever`] combines a [`KeywordRetriever`] and any [`Retriever`] (e.g.,
4//! [`EmbeddingRetriever`]) with configurable score weighting, then re-ranks results.
5//!
6//! Also provides enhanced grounding strategies:
7//! - [`CitationStrategy`] — formats docs with citation numbers and source IDs
8//! - [`ContextWindowStrategy`] — limits injected context to a token budget
9//!
10//! [`EmbeddingRetriever`]: crate::embedding::EmbeddingRetriever
11
12use async_trait::async_trait;
13
14use crate::{Document, GroundingStrategy, KeywordRetriever, Retriever};
15
16// ─────────────────────────────────────────────────────────────────────────────
17// HybridRetriever
18// ─────────────────────────────────────────────────────────────────────────────
19
/// Combines keyword and semantic retrieval with configurable weighting.
///
/// Results from both retrievers are merged, normalized, and re-ranked by a
/// weighted sum of their individual scores.
///
/// # Example
///
/// ```rust
/// use traitclaw_rag::{Document, KeywordRetriever, Retriever};
/// use traitclaw_rag::hybrid::HybridRetriever;
///
/// # async fn example() -> traitclaw_core::Result<()> {
/// let mut keyword = KeywordRetriever::new();
/// keyword.add(Document::new("doc1", "Rust is fast"));
///
/// // Use keyword retriever as both sides for this example
/// let hybrid = HybridRetriever::new(keyword, KeywordRetriever::new());
/// let results = hybrid.retrieve("Rust", 5).await?;
/// # Ok(())
/// # }
/// ```
pub struct HybridRetriever<E: Retriever> {
    // Lexical (keyword-match) side of the hybrid search.
    keyword: KeywordRetriever,
    // Semantic side: any `Retriever` impl, typically an embedding retriever.
    embedding: E,
    // Multiplier applied to normalized keyword scores (default 0.3 via `new`).
    keyword_weight: f64,
    // Multiplier applied to normalized embedding scores (default 0.7 via `new`).
    embedding_weight: f64,
}
47
48impl<E: Retriever> HybridRetriever<E> {
49    /// Create a new `HybridRetriever` with default weights (0.3 keyword / 0.7 embedding).
50    #[must_use]
51    pub fn new(keyword: KeywordRetriever, embedding: E) -> Self {
52        Self {
53            keyword,
54            embedding,
55            keyword_weight: 0.3,
56            embedding_weight: 0.7,
57        }
58    }
59
60    /// Set custom score weights.
61    ///
62    /// Weights need not sum to 1.0 — they are used as multipliers.
63    ///
64    /// # Panics
65    ///
66    /// Panics if either weight is negative.
67    #[must_use]
68    pub fn with_weights(mut self, keyword_weight: f64, embedding_weight: f64) -> Self {
69        assert!(keyword_weight >= 0.0, "keyword_weight must be non-negative");
70        assert!(
71            embedding_weight >= 0.0,
72            "embedding_weight must be non-negative"
73        );
74        self.keyword_weight = keyword_weight;
75        self.embedding_weight = embedding_weight;
76        self
77    }
78}
79
80#[async_trait]
81impl<E: Retriever> Retriever for HybridRetriever<E> {
82    /// Retrieve from both sources, merge, normalize, and re-rank.
83    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
84        // Fetch from both — use a larger candidate set for better coverage
85        let candidate = (limit * 5).max(10);
86
87        let (kw_docs, emb_docs) = tokio::join!(
88            self.keyword.retrieve(query, candidate),
89            self.embedding.retrieve(query, candidate),
90        );
91        let kw_docs = kw_docs.unwrap_or_default();
92        let emb_docs = emb_docs.unwrap_or_default();
93
94        // Normalize scores within each set (0.0–1.0)
95        let kw_max = kw_docs
96            .iter()
97            .map(|d| d.score)
98            .fold(f64::NEG_INFINITY, f64::max);
99        let emb_max = emb_docs
100            .iter()
101            .map(|d| d.score)
102            .fold(f64::NEG_INFINITY, f64::max);
103
104        // Merge by doc id: combined_score = w_kw * norm_kw + w_emb * norm_emb
105        let mut scores: std::collections::HashMap<String, (f64, &Document)> =
106            std::collections::HashMap::new();
107
108        for doc in &kw_docs {
109            let norm = if kw_max > 0.0 {
110                doc.score / kw_max
111            } else {
112                0.0
113            };
114            scores
115                .entry(doc.id.clone())
116                .and_modify(|(s, _)| *s += self.keyword_weight * norm)
117                .or_insert((self.keyword_weight * norm, doc));
118        }
119
120        for doc in &emb_docs {
121            let norm = if emb_max > 0.0 {
122                doc.score / emb_max
123            } else {
124                0.0
125            };
126            scores
127                .entry(doc.id.clone())
128                .and_modify(|(s, _)| *s += self.embedding_weight * norm)
129                .or_insert((self.embedding_weight * norm, doc));
130        }
131
132        let mut combined: Vec<Document> = scores
133            .into_values()
134            .map(|(combined_score, doc)| {
135                let mut d = doc.clone();
136                d.score = combined_score;
137                d
138            })
139            .collect();
140
141        combined.sort_by(|a, b| {
142            b.score
143                .partial_cmp(&a.score)
144                .unwrap_or(std::cmp::Ordering::Equal)
145        });
146        combined.truncate(limit);
147
148        Ok(combined)
149    }
150}
151
152// ─────────────────────────────────────────────────────────────────────────────
153// CitationStrategy
154// ─────────────────────────────────────────────────────────────────────────────
155
/// Grounding strategy that uses numbered citations with source IDs.
///
/// Format: `[1] content (Source: doc_id)\n`
///
/// # Example
///
/// ```rust
/// use traitclaw_rag::{Document, GroundingStrategy};
/// use traitclaw_rag::hybrid::CitationStrategy;
///
/// let docs = vec![Document::new("paper-42", "Important finding.")];
/// let ctx = CitationStrategy.ground(&docs);
/// assert!(ctx.contains("[1]"));
/// assert!(ctx.contains("Source: paper-42"));
/// ```
// Standard derives for a public, stateless unit type: cheap to copy,
// comparable, and debuggable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CitationStrategy;
172
173impl GroundingStrategy for CitationStrategy {
174    fn ground(&self, documents: &[Document]) -> String {
175        if documents.is_empty() {
176            return String::new();
177        }
178        let mut ctx = String::from("Context:\n\n");
179        for (i, doc) in documents.iter().enumerate() {
180            use std::fmt::Write;
181            let _ = writeln!(ctx, "[{}] {} (Source: {})", i + 1, doc.content, doc.id);
182        }
183        ctx
184    }
185}
186
187// ─────────────────────────────────────────────────────────────────────────────
188// ContextWindowStrategy
189// ─────────────────────────────────────────────────────────────────────────────
190
/// Grounding strategy that limits injected context to a token budget.
///
/// Uses a simple 4-chars-per-token heuristic. Documents are added in order
/// until the budget would be exceeded.
///
/// # Example
///
/// ```rust
/// use traitclaw_rag::{Document, GroundingStrategy};
/// use traitclaw_rag::hybrid::ContextWindowStrategy;
///
/// let large_doc = Document::new("d1", &"word ".repeat(1000));
/// let strategy = ContextWindowStrategy::new(50); // very small budget
/// let ctx = strategy.ground(&[large_doc]);
/// // Context is truncated to fit budget
/// assert!(ctx.chars().count() < 400); // 50 tokens * 4 chars each
/// ```
// Public config-like type: derive Debug/Clone so callers can log and reuse it.
#[derive(Debug, Clone)]
pub struct ContextWindowStrategy {
    // Token budget; converted to a character budget at 4 chars/token.
    max_tokens: usize,
}

impl ContextWindowStrategy {
    /// Create a new `ContextWindowStrategy` with the given token budget.
    ///
    /// # Panics
    ///
    /// Panics if `max_tokens == 0`.
    #[must_use]
    pub fn new(max_tokens: usize) -> Self {
        assert!(max_tokens > 0, "max_tokens must be > 0");
        Self { max_tokens }
    }
}
224
impl GroundingStrategy for ContextWindowStrategy {
    /// Append documents as numbered `[n] content` entries until the character
    /// budget (`max_tokens * 4`) would be exceeded; the first overflowing
    /// document is truncated to the remaining space and iteration stops.
    fn ground(&self, documents: &[Document]) -> String {
        if documents.is_empty() {
            return String::new();
        }

        // 4 chars per token heuristic
        let char_budget = self.max_tokens * 4;
        let mut ctx = String::from("Context:\n\n");
        // NOTE(review): `used` counts bytes (`String::len`), while the
        // truncation below counts chars — identical for ASCII and
        // conservative (never over-budget) for multi-byte content.
        let mut used = ctx.len();

        for (i, doc) in documents.iter().enumerate() {
            use std::fmt::Write;
            let entry = format!("[{}] {}\n\n", i + 1, doc.content);

            if used + entry.len() > char_budget {
                // Try partial: trim doc content to fit
                // The constant 10 approximates the "[n] " prefix plus the
                // "…\n\n" suffix; not an exact accounting.
                let available = char_budget.saturating_sub(used + 10); // header overhead
                // Skip the partial entry entirely when fewer than ~20 chars
                // remain — a tiny fragment adds no useful context.
                if available > 20 {
                    let truncated: String = doc.content.chars().take(available).collect();
                    let _ = write!(ctx, "[{}] {}…\n\n", i + 1, truncated);
                }
                break;
            }

            ctx.push_str(&entry);
            used += entry.len();
        }

        ctx
    }
}
257
258// ─────────────────────────────────────────────────────────────────────────────
259// Tests
260// ─────────────────────────────────────────────────────────────────────────────
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Build a `KeywordRetriever` seeded with the given `(id, content)` pairs.
    fn kw_retriever_with(docs: Vec<(&str, &str)>) -> KeywordRetriever {
        let mut retriever = KeywordRetriever::new();
        docs.into_iter()
            .for_each(|(id, content)| retriever.add(Document::new(id, content)));
        retriever
    }

    // ── HybridRetriever ───────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_hybrid_returns_from_keyword_source() {
        // AC #7: hybrid returns results from keyword source
        let keyword_side =
            kw_retriever_with(vec![("k1", "Rust programming"), ("k2", "Python code")]);
        let empty_embedding_side = KeywordRetriever::new();

        let results = HybridRetriever::new(keyword_side, empty_embedding_side)
            .retrieve("Rust", 5)
            .await
            .unwrap();

        // Should still get keyword results even with empty embedding side
        assert!(!results.is_empty(), "expected keyword results");
        assert!(results.iter().any(|d| d.id == "k1"));
    }

    #[tokio::test]
    async fn test_hybrid_merges_both_sources() {
        // AC #7: hybrid returns from both sources
        let hybrid = HybridRetriever::new(
            kw_retriever_with(vec![("k1", "Rust keyword hit")]),
            kw_retriever_with(vec![("e1", "Rust embedding hit")]),
        );
        let results = hybrid.retrieve("Rust hit", 10).await.unwrap();

        let ids: Vec<_> = results.iter().map(|d| d.id.as_str()).collect();
        assert!(
            ids.contains(&"k1") || ids.contains(&"e1"),
            "should contain results from both: {ids:?}"
        );
    }

    #[tokio::test]
    async fn test_hybrid_respects_limit() {
        let hybrid = HybridRetriever::new(
            kw_retriever_with(vec![("k1", "Rust a"), ("k2", "Rust b"), ("k3", "Rust c")]),
            kw_retriever_with(vec![("e1", "Rust d"), ("e2", "Rust e")]),
        );
        let results = hybrid.retrieve("Rust", 2).await.unwrap();
        assert!(results.len() <= 2);
    }

    #[tokio::test]
    async fn test_hybrid_combined_score_sorted_desc() {
        let hybrid = HybridRetriever::new(
            kw_retriever_with(vec![("k1", "Rust programming")]),
            kw_retriever_with(vec![("e1", "Rust embedding search")]),
        );
        let results = hybrid.retrieve("Rust", 10).await.unwrap();

        // Scores must be non-increasing down the ranked list.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    // ── CitationStrategy ─────────────────────────────────────────────────────

    #[test]
    fn test_citation_strategy_format() {
        // AC #5: formats as [1] content (Source: doc_id)
        let docs = vec![
            Document::new("paper-42", "Important finding."),
            Document::new("blog-7", "Another insight."),
        ];
        let ctx = CitationStrategy.ground(&docs);
        for needle in ["[1]", "Source: paper-42", "[2]", "Source: blog-7"] {
            assert!(ctx.contains(needle));
        }
    }

    #[test]
    fn test_citation_strategy_empty() {
        let ctx = CitationStrategy.ground(&[]);
        assert!(ctx.is_empty());
    }

    // ── ContextWindowStrategy ────────────────────────────────────────────────

    #[test]
    fn test_context_window_small_docs_fit() {
        let docs = vec![
            Document::new("d1", "Short."),
            Document::new("d2", "Also short."),
        ];
        // 1000 tokens = 4000 chars — easily fits 2 tiny docs
        let ctx = ContextWindowStrategy::new(1000).ground(&docs);
        assert!(ctx.contains("[1]"));
        assert!(ctx.contains("[2]"));
    }

    #[test]
    fn test_context_window_truncates_large_doc() {
        // AC #8: truncates when context exceeds budget
        let large = "word ".repeat(500); // 2500 chars
        let strategy = ContextWindowStrategy::new(50); // 200 char budget
        let ctx = strategy.ground(&[Document::new("big", &large)]);

        // Output must be significantly smaller than the full doc
        assert!(
            ctx.chars().count() < 500,
            "expected truncation, got {} chars",
            ctx.chars().count()
        );
    }

    #[test]
    fn test_context_window_empty_docs() {
        assert!(ContextWindowStrategy::new(100).ground(&[]).is_empty());
    }
}