traitclaw_rag/
rag_context.rs1use async_trait::async_trait;
28use traitclaw_core::{
29 traits::context_manager::ContextManager,
30 types::{
31 agent_state::AgentState,
32 message::{Message, MessageRole},
33 },
34};
35
36use crate::{GroundingStrategy, PrependStrategy, Retriever};
37
38pub struct RagContextManager<R: Retriever, G: GroundingStrategy = PrependStrategy> {
40 retriever: R,
41 grounding: G,
42 max_docs: usize,
43}
44
45impl<R: Retriever> RagContextManager<R, PrependStrategy> {
46 #[must_use]
48 pub fn new(retriever: R) -> Self {
49 Self {
50 retriever,
51 grounding: PrependStrategy,
52 max_docs: 5,
53 }
54 }
55}
56
57impl<R: Retriever, G: GroundingStrategy> RagContextManager<R, G> {
58 #[must_use]
60 pub fn with_grounding<G2: GroundingStrategy>(self, grounding: G2) -> RagContextManager<R, G2> {
61 RagContextManager {
62 retriever: self.retriever,
63 grounding,
64 max_docs: self.max_docs,
65 }
66 }
67
68 #[must_use]
70 pub fn with_max_docs(mut self, max_docs: usize) -> Self {
71 self.max_docs = max_docs;
72 self
73 }
74}
75
76#[async_trait]
77impl<R: Retriever, G: GroundingStrategy> ContextManager for RagContextManager<R, G> {
78 async fn prepare(
80 &self,
81 messages: &mut Vec<Message>,
82 _context_window: usize,
83 _state: &mut AgentState,
84 ) {
85 let query = messages
87 .iter()
88 .rev()
89 .find(|m| m.role == MessageRole::User)
90 .map(|m| m.content.clone())
91 .unwrap_or_default();
92
93 if query.is_empty() {
94 return;
95 }
96
97 let docs = match self.retriever.retrieve(&query, self.max_docs).await {
99 Ok(d) => d,
100 Err(_) => return, };
102
103 if docs.is_empty() {
104 return;
105 }
106
107 let grounded_context = self.grounding.ground(&docs);
109
110 if grounded_context.is_empty() {
111 return;
112 }
113
114 messages.insert(
116 0,
117 Message {
118 role: MessageRole::System,
119 content: grounded_context,
120 tool_call_id: None,
121 },
122 );
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use traitclaw_core::types::model_info::ModelTier;

    use super::*;
    use crate::{Document, KeywordRetriever};

    /// Builds a user-role message with no tool call id.
    fn user_msg(content: &str) -> Message {
        Message {
            role: MessageRole::User,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    /// Builds a system-role message with no tool call id.
    fn system_msg(content: &str) -> Message {
        Message {
            role: MessageRole::System,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    /// Builds a keyword retriever preloaded with `(id, content)` documents.
    fn make_retriever(docs: Vec<(&str, &str)>) -> KeywordRetriever {
        let mut r = KeywordRetriever::new();
        for (id, content) in docs {
            r.add(Document::new(id, content));
        }
        r
    }

    fn state() -> AgentState {
        AgentState::new(ModelTier::Large, 128_000)
    }

    #[tokio::test]
    async fn test_rag_context_manager_prepends_grounding() {
        let retriever = make_retriever(vec![
            ("d1", "Rust is a systems language"),
            ("d2", "Rust has zero-cost abstractions"),
            ("d3", "Rust ownership model"),
        ]);

        let manager = RagContextManager::new(retriever);
        let mut messages = vec![user_msg("Tell me about Rust")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        // Exactly one grounding message is inserted ahead of the user message.
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, MessageRole::System);
        assert!(
            messages[0].content.contains("Rust"),
            "context should mention Rust"
        );
        assert_eq!(messages[1].role, MessageRole::User);
    }

    #[tokio::test]
    async fn test_rag_no_relevant_docs_unchanged() {
        let retriever = make_retriever(vec![("d1", "Python is great for data science")]);

        let manager = RagContextManager::new(retriever);
        let mut messages = vec![user_msg("Tell me about quantum computing")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::User);
    }

    #[tokio::test]
    async fn test_rag_empty_messages_unchanged() {
        let retriever = make_retriever(vec![("d1", "some content")]);
        let manager = RagContextManager::new(retriever);
        let mut messages: Vec<Message> = vec![];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;
        assert!(messages.is_empty());
    }

    #[tokio::test]
    async fn test_rag_max_docs_limits_injection() {
        let retriever = make_retriever(vec![
            ("d1", "Rust systems programming"),
            ("d2", "Rust async programming"),
            ("d3", "Rust embedded programming"),
        ]);

        let manager = RagContextManager::new(retriever).with_max_docs(1);
        let mut messages = vec![user_msg("Rust programming")];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        assert_eq!(messages[0].role, MessageRole::System);
        assert!(
            messages[0].content.contains("[1]"),
            "should have first citation"
        );
        assert!(
            !messages[0].content.contains("[2]"),
            "should not have second citation with max_docs=1"
        );
    }

    #[tokio::test]
    async fn test_rag_user_message_found_among_others() {
        let retriever = make_retriever(vec![("d1", "Rust systems programming")]);
        let manager = RagContextManager::new(retriever);

        let mut messages = vec![
            system_msg("You are a helpful assistant"),
            user_msg("Tell me about Rust"),
        ];
        let mut st = state();
        manager.prepare(&mut messages, 128_000, &mut st).await;

        // The grounding message is injected at the front and the existing
        // conversation (system prompt, then user turn) follows intact.
        assert_eq!(messages.len(), 3, "grounding message should be injected");
        assert_eq!(messages[0].role, MessageRole::System);
        assert!(
            messages[0].content.contains("Rust"),
            "context should mention Rust"
        );
        assert_eq!(messages[1].role, MessageRole::System);
        assert_eq!(messages[1].content, "You are a helpful assistant");
        assert_eq!(messages[2].role, MessageRole::User);
    }

    // Purely a compile-time object-safety check; no async runtime is needed.
    #[test]
    fn test_rag_implements_context_manager_trait() {
        let retriever = KeywordRetriever::new();
        let manager = RagContextManager::new(retriever);
        let _: Arc<dyn ContextManager> = Arc::new(manager);
    }
}