1use crate::error::Result;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7use ucm_core::BlockId;
8
9#[derive(Debug, Clone, Default)]
11pub struct RagSearchOptions {
12 pub limit: usize,
14 pub min_similarity: f32,
16 pub filter_block_ids: Option<HashSet<BlockId>>,
18 pub filter_roles: Option<HashSet<String>>,
20 pub filter_tags: Option<HashSet<String>>,
22 pub include_content: bool,
24}
25
26impl RagSearchOptions {
27 pub fn new() -> Self {
28 Self {
29 limit: 10,
30 min_similarity: 0.0,
31 filter_block_ids: None,
32 filter_roles: None,
33 filter_tags: None,
34 include_content: true,
35 }
36 }
37
38 pub fn with_limit(mut self, limit: usize) -> Self {
39 self.limit = limit;
40 self
41 }
42
43 pub fn with_min_similarity(mut self, threshold: f32) -> Self {
44 self.min_similarity = threshold;
45 self
46 }
47
48 pub fn with_roles(mut self, roles: impl IntoIterator<Item = String>) -> Self {
49 self.filter_roles = Some(roles.into_iter().collect());
50 self
51 }
52
53 pub fn with_tags(mut self, tags: impl IntoIterator<Item = String>) -> Self {
54 self.filter_tags = Some(tags.into_iter().collect());
55 self
56 }
57
58 pub fn with_block_ids(mut self, ids: impl IntoIterator<Item = BlockId>) -> Self {
59 self.filter_block_ids = Some(ids.into_iter().collect());
60 self
61 }
62}
63
/// One entry in a [`RagSearchResults`] match list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagMatch {
    /// Identifier of the block this match refers to.
    pub block_id: BlockId,
    /// Similarity value for this match; `MockRagProvider` compares it against
    /// `RagSearchOptions::min_similarity`. The value range is not established
    /// in this file.
    pub similarity: f32,
    /// Optional preview text; `MockRagProvider::add_result` stores the caller's
    /// `preview` argument here.
    pub content_preview: Option<String>,
    /// Optional role label for the block.
    // NOTE(review): presumably the value that `RagSearchOptions::filter_roles`
    // is meant to match -- confirm against a real provider.
    pub semantic_role: Option<String>,
    /// Pairs of `usize` values describing highlighted regions.
    // NOTE(review): presumably (start, end) offsets into `content_preview`;
    // unit (bytes vs. chars) and end-inclusivity are not shown here -- confirm.
    pub highlight_spans: Vec<(usize, usize)>,
}
78
/// Outcome of a [`RagProvider::search`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSearchResults {
    /// The matches returned for the query.
    pub matches: Vec<RagMatch>,
    /// The query string; the providers in this file copy the caller's query
    /// into it unchanged.
    pub query: String,
    /// Number of candidates considered. `MockRagProvider` reports the count of
    /// its stored results before filtering; `RagSearchResults::empty` uses 0.
    pub total_searched: usize,
    /// Reported execution time (milliseconds, going by the field name).
    /// `RagSearchResults::empty` uses 0 and `MockRagProvider` reports a
    /// constant 1.
    pub execution_time_ms: u64,
}
91
92impl RagSearchResults {
93 pub fn empty(query: String) -> Self {
94 Self {
95 matches: Vec::new(),
96 query,
97 total_searched: 0,
98 execution_time_ms: 0,
99 }
100 }
101
102 pub fn block_ids(&self) -> Vec<BlockId> {
103 self.matches.iter().map(|m| m.block_id).collect()
104 }
105}
106
/// Feature flags and size limits a [`RagProvider`] reports about itself via
/// `RagProvider::capabilities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagCapabilities {
    /// Whether the provider reports search support (`NullRagProvider`
    /// reports `false`).
    pub supports_search: bool,
    /// Whether the provider reports embedding support.
    pub supports_embedding: bool,
    /// Whether the provider reports filtering support.
    pub supports_filtering: bool,
    /// Largest query length the provider reports accepting.
    // NOTE(review): unit (bytes vs. characters) is not established here -- confirm.
    pub max_query_length: usize,
    /// Largest number of results the provider reports returning.
    pub max_results: usize,
}
121
122impl Default for RagCapabilities {
123 fn default() -> Self {
124 Self {
125 supports_search: true,
126 supports_embedding: false,
127 supports_filtering: true,
128 max_query_length: 1000,
129 max_results: 100,
130 }
131 }
132}
133
/// Common interface for search providers: a provider takes a text query plus
/// [`RagSearchOptions`] and yields [`RagSearchResults`].
///
/// Implementors must be `Send + Sync`. `#[async_trait]` is applied so the
/// trait can declare `async fn` methods.
#[async_trait]
pub trait RagProvider: Send + Sync {
    /// Runs a search for `query` under `options` and returns the results, or
    /// the crate's error type on failure.
    async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults>;

    /// Produces an embedding vector for `content`.
    ///
    /// The default implementation ignores `content` and returns `Ok` with an
    /// empty vector.
    // NOTE(review): presumably providers that report `supports_embedding`
    // override this; callers cannot tell "unsupported" from an empty
    // embedding by the return value alone -- confirm intended contract.
    async fn embed(&self, content: &str) -> Result<Vec<f32>> {
        // Explicitly discard the argument so the default body does not trip
        // the unused-variable lint while keeping the parameter name.
        let _ = content;
        Ok(Vec::new())
    }

    /// Reports what this provider claims to support and its size limits.
    fn capabilities(&self) -> RagCapabilities;

    /// Returns the provider's name (e.g. `"null"` or `"mock"` for the
    /// providers defined in this file).
    fn name(&self) -> &str;
}
152
/// Stateless [`RagProvider`] that never returns matches and reports every
/// capability as unsupported.
pub struct NullRagProvider;
155
156#[async_trait]
157impl RagProvider for NullRagProvider {
158 async fn search(&self, query: &str, _options: RagSearchOptions) -> Result<RagSearchResults> {
159 Ok(RagSearchResults::empty(query.to_string()))
160 }
161
162 fn capabilities(&self) -> RagCapabilities {
163 RagCapabilities {
164 supports_search: false,
165 supports_embedding: false,
166 supports_filtering: false,
167 max_query_length: 0,
168 max_results: 0,
169 }
170 }
171
172 fn name(&self) -> &str {
173 "null"
174 }
175}
176
/// [`RagProvider`] backed by a fixed, in-memory list of matches.
pub struct MockRagProvider {
    /// Stored matches, in insertion order; `search` scans them in this order.
    results: Vec<RagMatch>,
}
181
182impl MockRagProvider {
183 pub fn new() -> Self {
184 Self {
185 results: Vec::new(),
186 }
187 }
188
189 pub fn with_results(mut self, results: Vec<RagMatch>) -> Self {
190 self.results = results;
191 self
192 }
193
194 pub fn add_result(&mut self, block_id: BlockId, similarity: f32, preview: Option<&str>) {
195 self.results.push(RagMatch {
196 block_id,
197 similarity,
198 content_preview: preview.map(String::from),
199 semantic_role: None,
200 highlight_spans: Vec::new(),
201 });
202 }
203}
204
205impl Default for MockRagProvider {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211#[async_trait]
212impl RagProvider for MockRagProvider {
213 async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults> {
214 let matches: Vec<_> = self
215 .results
216 .iter()
217 .filter(|m| m.similarity >= options.min_similarity)
218 .take(options.limit)
219 .cloned()
220 .collect();
221
222 Ok(RagSearchResults {
223 matches,
224 query: query.to_string(),
225 total_searched: self.results.len(),
226 execution_time_ms: 1,
227 })
228 }
229
230 fn capabilities(&self) -> RagCapabilities {
231 RagCapabilities::default()
232 }
233
234 fn name(&self) -> &str {
235 "mock"
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BlockId` for a test: the parsed value when `s` parses as a
    /// `BlockId`, otherwise the bytes of `s` XOR-folded into a 12-byte array.
    fn block_id(s: &str) -> BlockId {
        match s.parse::<BlockId>() {
            Ok(parsed) => parsed,
            Err(_) => {
                let mut folded = [0u8; 12];
                for (pos, byte) in s.bytes().enumerate() {
                    folded[pos % 12] ^= byte;
                }
                BlockId::from_bytes(folded)
            }
        }
    }

    /// The null provider echoes the query and returns no matches.
    #[tokio::test]
    async fn test_null_provider() {
        let null = NullRagProvider;
        let options = RagSearchOptions::new();

        let outcome = null.search("test query", options).await.unwrap();

        assert!(outcome.matches.is_empty());
        assert_eq!(outcome.query, "test query");
    }

    /// The mock provider returns stored results in insertion order when the
    /// limit is not reached.
    #[tokio::test]
    async fn test_mock_provider() {
        let mut mock = MockRagProvider::new();
        mock.add_result(block_id("blk_000000000001"), 0.9, Some("test content"));
        mock.add_result(block_id("blk_000000000002"), 0.8, None);

        let options = RagSearchOptions::new().with_limit(5);
        let outcome = mock.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 2);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }

    /// Results below the similarity floor are dropped by the mock provider.
    #[tokio::test]
    async fn test_mock_provider_filtering() {
        let mut mock = MockRagProvider::new();
        mock.add_result(block_id("blk_000000000001"), 0.9, None);
        mock.add_result(block_id("blk_000000000002"), 0.5, None);

        let options = RagSearchOptions::new().with_min_similarity(0.7);
        let outcome = mock.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 1);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }
}