Skip to main content

traitclaw_rag/
rag_context.rs

//! `RagContextManager` — automatic document retrieval and injection.
//!
//! Implements [`ContextManager`] from `traitclaw-core` so that a RAG pipeline
//! integrates transparently with the `Agent` builder.
//!
//! On every `prepare()` call, the manager:
//! 1. Extracts the last user message as the retrieval query
//! 2. Retrieves relevant documents via the configured [`Retriever`]
//! 3. Formats them with the configured [`GroundingStrategy`]
//! 4. Prepends the grounded context as a **system message** at the front of the
//!    conversation
//!
//! If there is no usable query, retrieval fails, or nothing relevant is found,
//! the messages are left unchanged.
//!
//! # Example
//!
//! ```rust
//! use traitclaw_rag::{Document, KeywordRetriever};
//! use traitclaw_rag::rag_context::RagContextManager;
//!
//! # async fn example() -> traitclaw_core::Result<()> {
//! let mut retriever = KeywordRetriever::new();
//! retriever.add(Document::new("doc1", "Rust is a systems language."));
//!
//! let manager = RagContextManager::new(retriever);
//! # Ok(())
//! # }
//! ```
26
27use async_trait::async_trait;
28use traitclaw_core::{
29    traits::context_manager::ContextManager,
30    types::{
31        agent_state::AgentState,
32        message::{Message, MessageRole},
33    },
34};
35
36use crate::{GroundingStrategy, PrependStrategy, Retriever};
37
38/// A [`ContextManager`] that retrieves documents and injects them as grounded context.
39pub struct RagContextManager<R: Retriever, G: GroundingStrategy = PrependStrategy> {
40    retriever: R,
41    grounding: G,
42    max_docs: usize,
43}
44
45impl<R: Retriever> RagContextManager<R, PrependStrategy> {
46    /// Create a new `RagContextManager` with `PrependStrategy` as the default grounding.
47    #[must_use]
48    pub fn new(retriever: R) -> Self {
49        Self {
50            retriever,
51            grounding: PrependStrategy,
52            max_docs: 5,
53        }
54    }
55}
56
57impl<R: Retriever, G: GroundingStrategy> RagContextManager<R, G> {
58    /// Set the grounding strategy used to format retrieved documents.
59    #[must_use]
60    pub fn with_grounding<G2: GroundingStrategy>(self, grounding: G2) -> RagContextManager<R, G2> {
61        RagContextManager {
62            retriever: self.retriever,
63            grounding,
64            max_docs: self.max_docs,
65        }
66    }
67
68    /// Maximum number of retrieved documents to inject (default: 5).
69    #[must_use]
70    pub fn with_max_docs(mut self, max_docs: usize) -> Self {
71        self.max_docs = max_docs;
72        self
73    }
74}
75
76#[async_trait]
77impl<R: Retriever, G: GroundingStrategy> ContextManager for RagContextManager<R, G> {
78    /// Prepare messages: retrieve docs for last user query, prepend as system message.
79    async fn prepare(
80        &self,
81        messages: &mut Vec<Message>,
82        _context_window: usize,
83        _state: &mut AgentState,
84    ) {
85        // Extract the last user message as the query
86        let query = messages
87            .iter()
88            .rev()
89            .find(|m| m.role == MessageRole::User)
90            .map(|m| m.content.clone())
91            .unwrap_or_default();
92
93        if query.is_empty() {
94            return;
95        }
96
97        // Retrieve relevant documents
98        let docs = match self.retriever.retrieve(&query, self.max_docs).await {
99            Ok(d) => d,
100            Err(_) => return, // fail silently — don't break the agent
101        };
102
103        if docs.is_empty() {
104            return;
105        }
106
107        // Format with grounding strategy
108        let grounded_context = self.grounding.ground(&docs);
109
110        if grounded_context.is_empty() {
111            return;
112        }
113
114        // Prepend as system message
115        messages.insert(
116            0,
117            Message {
118                role: MessageRole::System,
119                content: grounded_context,
120                tool_call_id: None,
121            },
122        );
123    }
124}
125
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use traitclaw_core::types::model_info::ModelTier;

    use super::*;
    use crate::{Document, KeywordRetriever};

    fn user_msg(content: &str) -> Message {
        Message {
            role: MessageRole::User,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    fn system_msg(content: &str) -> Message {
        Message {
            role: MessageRole::System,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    fn make_retriever(docs: Vec<(&str, &str)>) -> KeywordRetriever {
        let mut r = KeywordRetriever::new();
        for (id, content) in docs {
            r.add(Document::new(id, content));
        }
        r
    }

    fn state() -> AgentState {
        AgentState::new(ModelTier::Large, 128_000)
    }

    #[tokio::test]
    async fn test_rag_context_manager_prepends_grounding() {
        // AC #5, #6: retriever returns docs → context prepended as system message
        let retriever = make_retriever(vec![
            ("d1", "Rust is a systems language"),
            ("d2", "Rust has zero-cost abstractions"),
            ("d3", "Rust ownership model"),
        ]);

        let manager = RagContextManager::new(retriever);
        let mut messages = vec![user_msg("Tell me about Rust")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        // First message should be a system message with context
        assert_eq!(messages[0].role, MessageRole::System);
        assert!(
            messages[0].content.contains("Rust"),
            "context should mention Rust"
        );
        // Original user message preserved at index 1
        assert_eq!(messages[1].role, MessageRole::User);
    }

    #[tokio::test]
    async fn test_rag_no_relevant_docs_unchanged() {
        // AC #7: no relevant docs → context unchanged
        let retriever = make_retriever(vec![("d1", "Python is great for data science")]);

        let manager = RagContextManager::new(retriever);
        let mut messages = vec![user_msg("Tell me about quantum computing")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        // No system message prepended
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::User);
    }

    #[tokio::test]
    async fn test_rag_empty_messages_unchanged() {
        let retriever = make_retriever(vec![("d1", "some content")]);
        let manager = RagContextManager::new(retriever);
        let mut messages: Vec<Message> = vec![];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;
        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn test_rag_no_user_message_unchanged() {
        // Without a user turn there is no query, so nothing is injected.
        let retriever = make_retriever(vec![("d1", "Rust systems programming")]);
        let manager = RagContextManager::new(retriever);
        let mut messages = vec![system_msg("You are a helpful assistant")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::System);
        assert_eq!(messages[0].content, "You are a helpful assistant");
    }

    #[tokio::test]
    async fn test_rag_max_docs_limits_injection() {
        // AC #3: with_max_docs(1) → only 1 doc injected
        let retriever = make_retriever(vec![
            ("d1", "Rust systems programming"),
            ("d2", "Rust async programming"),
            ("d3", "Rust embedded programming"),
        ]);

        let manager = RagContextManager::new(retriever).with_max_docs(1);
        let mut messages = vec![user_msg("Rust programming")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        assert_eq!(messages[0].role, MessageRole::System);
        // PrependStrategy format: "[1] content\n\n[2] content..."
        // With max_docs=1 there should only be [1] and no [2]
        assert!(
            messages[0].content.contains("[1]"),
            "should have first citation"
        );
        assert!(
            !messages[0].content.contains("[2]"),
            "should not have second citation with max_docs=1"
        );
    }

    #[tokio::test]
    async fn test_rag_user_message_found_among_others() {
        // query is extracted from last user message even when system messages present
        let retriever = make_retriever(vec![("d1", "Rust systems programming")]);
        let manager = RagContextManager::new(retriever);

        let mut messages = vec![
            system_msg("You are a helpful assistant"),
            user_msg("Tell me about Rust"),
        ];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        // Grounding goes first; the original system prompt and user turn follow in order.
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, MessageRole::System);
        assert!(
            messages[0].content.contains("Rust"),
            "context should mention Rust"
        );
        assert_eq!(messages[1].role, MessageRole::System);
        assert_eq!(messages[1].content, "You are a helpful assistant");
        assert_eq!(messages[2].role, MessageRole::User);
        assert_eq!(messages[2].content, "Tell me about Rust");
    }

    #[tokio::test]
    async fn test_rag_uses_last_user_message_as_query() {
        // An earlier user turn that would match must not drive retrieval.
        let retriever = make_retriever(vec![("d1", "Rust ownership model")]);
        let manager = RagContextManager::new(retriever);

        let mut messages = vec![user_msg("Rust"), user_msg("quantum computing")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        assert_eq!(messages.len(), 2);
        assert!(messages.iter().all(|m| m.role == MessageRole::User));
    }

    #[tokio::test]
    async fn test_rag_implements_context_manager_trait() {
        // Can be used as Arc<dyn ContextManager>
        let retriever = KeywordRetriever::new();
        let manager = RagContextManager::new(retriever);
        let _: Arc<dyn ContextManager> = Arc::new(manager);
    }
}