Skip to main content

traitclaw_rag/
lib.rs

1//! RAG (Retrieval-Augmented Generation) pipeline for the `TraitClaw` AI agent framework.
2//!
3//! Provides a `Retriever` trait, grounding strategies, and a built-in
4//! `KeywordRetriever` with BM25-style scoring for text search.
5//!
6//! # Quick Start
7//!
8//! ```rust
9//! use traitclaw_rag::{Document, KeywordRetriever, Retriever, GroundingStrategy, PrependStrategy};
10//!
11//! # async fn example() -> traitclaw_core::Result<()> {
12//! let mut retriever = KeywordRetriever::new();
13//! retriever.add(Document::new("doc1", "Rust is a systems programming language"));
14//! retriever.add(Document::new("doc2", "Python is great for AI"));
15//!
16//! let docs = retriever.retrieve("Rust systems", 5).await?;
17//! assert!(!docs.is_empty());
18//!
19//! let strategy = PrependStrategy;
20//! let context = strategy.ground(&docs);
21//! assert!(context.contains("Rust"));
22//! # Ok(())
23//! # }
24//! ```
25
26#![deny(missing_docs)]
27#![allow(clippy::redundant_closure)]
28
29pub mod chunker;
30pub mod embedding;
31pub mod hybrid;
32pub mod rag_context;
33
34use async_trait::async_trait;
35use serde::{Deserialize, Serialize};
36
37pub use chunker::{Chunker, FixedSizeChunker, RecursiveChunker, SentenceChunker};
38pub use embedding::{EmbeddingProvider, EmbeddingRetriever};
39pub use hybrid::{CitationStrategy, ContextWindowStrategy, HybridRetriever};
40pub use rag_context::RagContextManager;
41
/// A document for retrieval.
///
/// Bundles the searchable text with an identifier, optional JSON metadata,
/// and the relevance score a retriever assigns to it. The type derives
/// `Serialize`/`Deserialize`, so it can be round-tripped through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique document identifier.
    ///
    /// Uniqueness is a convention; this type does not enforce it.
    pub id: String,
    /// Document content.
    pub content: String,
    /// Optional metadata, stored as an arbitrary JSON value.
    pub metadata: Option<serde_json::Value>,
    /// Relevance score (set by retriever).
    ///
    /// Starts at `0.0` for documents built with [`Document::new`].
    pub score: f64,
}
54
55impl Document {
56    /// Create a new document with no metadata and zero score.
57    #[must_use]
58    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
59        Self {
60            id: id.into(),
61            content: content.into(),
62            metadata: None,
63            score: 0.0,
64        }
65    }
66}
67
/// Trait for document retrieval.
///
/// Implement this for custom retrieval backends (vector DB, API, etc.).
///
/// Implementors must be `Send + Sync + 'static`. The method is declared
/// `async` through the `async_trait` attribute.
#[async_trait]
pub trait Retriever: Send + Sync + 'static {
    /// Retrieve relevant documents for a query.
    ///
    /// * `query` - the search text.
    /// * `limit` - the number of documents requested. The built-in
    ///   `KeywordRetriever` treats it as an upper bound on the result length;
    ///   other implementations are presumably expected to do the same.
    ///
    /// # Errors
    ///
    /// Returns a `traitclaw_core` error when the implementation cannot
    /// complete the lookup.
    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>>;
}
76
/// Strategy for grounding agent context with retrieved documents.
///
/// A strategy is a synchronous, infallible formatter: it receives the
/// retrieved documents and returns the text to hand to the agent.
/// Implementors must be `Send + Sync + 'static`.
pub trait GroundingStrategy: Send + Sync + 'static {
    /// Convert retrieved documents into context text for the agent.
    ///
    /// `documents` is the slice of documents to render; the returned `String`
    /// is the rendered context.
    fn ground(&self, documents: &[Document]) -> String;
}
82
/// Simple grounding strategy that prepends documents as numbered context.
///
/// This is a stateless unit type, so it is cheap to copy and can be created
/// either as `PrependStrategy` or via `PrependStrategy::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PrependStrategy;
85
86impl GroundingStrategy for PrependStrategy {
87    fn ground(&self, documents: &[Document]) -> String {
88        if documents.is_empty() {
89            return String::new();
90        }
91        let mut ctx = String::from("Relevant context:\n\n");
92        for (i, doc) in documents.iter().enumerate() {
93            use std::fmt::Write;
94            let _ = write!(ctx, "[{}] {}\n\n", i + 1, doc.content);
95        }
96        ctx
97    }
98}
99
100/// BM25-style keyword retriever for in-memory text search.
101///
102/// Scores documents using term frequency and inverse document frequency.
103pub struct KeywordRetriever {
104    documents: Vec<Document>,
105}
106
107impl KeywordRetriever {
108    /// Create a new empty keyword retriever.
109    #[must_use]
110    pub fn new() -> Self {
111        Self {
112            documents: Vec::new(),
113        }
114    }
115
116    /// Add a document to the index.
117    pub fn add(&mut self, doc: Document) {
118        self.documents.push(doc);
119    }
120
121    /// Add multiple documents to the index.
122    pub fn add_many(&mut self, docs: impl IntoIterator<Item = Document>) {
123        self.documents.extend(docs);
124    }
125
126    /// Score a document against query terms using BM25-like TF scoring.
127    fn score(query_terms: &[String], content: &str) -> f64 {
128        let content_lower = content.to_lowercase();
129        let words: Vec<&str> = content_lower.split_whitespace().collect();
130        let doc_len = words.len() as f64;
131
132        if doc_len == 0.0 {
133            return 0.0;
134        }
135
136        let mut total_score = 0.0;
137        for term in query_terms {
138            let tf = words.iter().filter(|w| **w == term.as_str()).count() as f64;
139            // BM25-like: tf / (tf + 1.2 * (1 - 0.75 + 0.75 * doc_len / avg_len))
140            // Simplified: use tf / (tf + 1.0)
141            let score = tf / (tf + 1.0);
142            total_score += score;
143        }
144
145        total_score
146    }
147}
148
149impl Default for KeywordRetriever {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155#[async_trait]
156impl Retriever for KeywordRetriever {
157    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
158        let terms: Vec<String> = query
159            .to_lowercase()
160            .split_whitespace()
161            .map(String::from)
162            .collect();
163
164        let mut scored: Vec<Document> = self
165            .documents
166            .iter()
167            .map(|doc| {
168                let mut d = doc.clone();
169                d.score = Self::score(&terms, &doc.content);
170                d
171            })
172            .filter(|d| d.score > 0.0)
173            .collect();
174
175        scored.sort_by(|a, b| {
176            b.score
177                .partial_cmp(&a.score)
178                .unwrap_or(std::cmp::Ordering::Equal)
179        });
180        scored.truncate(limit);
181
182        Ok(scored)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a retriever holding one document per `(id, content)` pair, in order.
    fn retriever_with(entries: &[(&str, &str)]) -> KeywordRetriever {
        let mut retriever = KeywordRetriever::new();
        for &(id, content) in entries {
            retriever.add(Document::new(id, content));
        }
        retriever
    }

    #[tokio::test]
    async fn test_keyword_retriever_basic() {
        let retriever = retriever_with(&[
            ("1", "Rust is a systems programming language"),
            ("2", "Python is great for data science"),
            ("3", "Rust has zero-cost abstractions"),
        ]);

        let hits = retriever.retrieve("Rust programming", 10).await.unwrap();

        assert!(!hits.is_empty());
        // Documents 1 and 3 both contain "rust".
        assert!(hits.len() >= 2);
        // Document 1 matches both query terms, so it ranks first.
        assert_eq!(hits[0].id, "1");
    }

    #[tokio::test]
    async fn test_keyword_retriever_empty_query() {
        let retriever = retriever_with(&[("1", "Some content")]);

        let hits = retriever.retrieve("", 10).await.unwrap();

        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_keyword_retriever_no_match() {
        let retriever = retriever_with(&[("1", "Hello world")]);

        let hits = retriever.retrieve("quantum computing", 10).await.unwrap();

        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_keyword_retriever_limit() {
        let mut retriever = KeywordRetriever::new();
        // Ten documents that all match the query term.
        (0..10).for_each(|i| {
            retriever.add(Document::new(format!("{i}"), format!("rust item {i}")));
        });

        let hits = retriever.retrieve("rust", 3).await.unwrap();

        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_prepend_strategy() {
        let docs = [
            Document::new("1", "First doc"),
            Document::new("2", "Second doc"),
        ];

        let grounded = PrependStrategy.ground(&docs);

        for expected in ["[1] First doc", "[2] Second doc"] {
            assert!(grounded.contains(expected));
        }
    }

    #[test]
    fn test_prepend_strategy_empty() {
        let grounded = PrependStrategy.ground(&[]);
        assert!(grounded.is_empty());
    }

    #[test]
    fn test_document_new() {
        let Document {
            id,
            content,
            metadata,
            score,
        } = Document::new("id1", "content1");

        assert_eq!(id, "id1");
        assert_eq!(content, "content1");
        assert!(metadata.is_none());
        assert!(score.abs() < f64::EPSILON);
    }
}