1use crate::{
4 chunk::{Chunk, Chunker, RecursiveChunker},
5 embed::{Embedder, MockEmbedder},
6 fusion::FusionStrategy,
7 index::{BM25Index, VectorStore},
8 rerank::{NoOpReranker, Reranker},
9 retrieve::{HybridRetriever, HybridRetrieverConfig, RetrievalResult},
10 Document, DocumentId, Error, Result,
11};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Embedding dimension that `RagPipelineConfig::default()` uses for `embedding_dimension`.
const DEFAULT_EMBEDDING_DIM: usize = 384;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Citation {
21 pub id: usize,
23 pub document_id: DocumentId,
25 pub chunk_id: crate::ChunkId,
27 pub title: Option<String>,
29 pub url: Option<String>,
31 pub page: Option<usize>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ContextChunk {
38 pub content: String,
40 pub citation_id: usize,
42 pub retrieval_score: f32,
44 pub rerank_score: Option<f32>,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AssembledContext {
51 pub chunks: Vec<ContextChunk>,
53 pub total_tokens: usize,
55 pub citations: Vec<Citation>,
57}
58
59impl AssembledContext {
60 #[must_use]
62 pub fn new() -> Self {
63 Self { chunks: Vec::new(), total_tokens: 0, citations: Vec::new() }
64 }
65
66 pub fn add_chunk(&mut self, result: &RetrievalResult, citation_id: usize) {
68 let chunk = ContextChunk {
69 content: result.chunk.content.clone(),
70 citation_id,
71 retrieval_score: result.best_score(),
72 rerank_score: result.rerank_score,
73 };
74
75 self.total_tokens += result.chunk.content.len() / 4;
77 self.chunks.push(chunk);
78 }
79
80 pub fn add_citation(&mut self, result: &RetrievalResult) -> usize {
82 let id = self.citations.len() + 1;
83
84 let citation = Citation {
85 id,
86 document_id: result.chunk.document_id,
87 chunk_id: result.chunk.id,
88 title: result.chunk.metadata.title.clone(),
89 url: None, page: result.chunk.metadata.page,
91 };
92
93 self.citations.push(citation);
94 id
95 }
96
97 #[must_use]
99 pub fn format_with_citations(&self) -> String {
100 self.chunks
101 .iter()
102 .map(|c| format!("{} [{}]", c.content, c.citation_id))
103 .collect::<Vec<_>>()
104 .join("\n\n")
105 }
106
107 #[must_use]
109 pub fn format_plain(&self) -> String {
110 self.chunks.iter().map(|c| c.content.as_str()).collect::<Vec<_>>().join("\n\n")
111 }
112
113 #[must_use]
115 pub fn citation_list(&self) -> String {
116 self.citations
117 .iter()
118 .map(|c| {
119 let title = c.title.as_deref().unwrap_or("Untitled");
120 format!("[{}] {}", c.id, title)
121 })
122 .collect::<Vec<_>>()
123 .join("\n")
124 }
125
126 #[must_use]
128 pub fn len(&self) -> usize {
129 self.chunks.len()
130 }
131
132 #[must_use]
134 pub fn is_empty(&self) -> bool {
135 self.chunks.is_empty()
136 }
137}
138
139impl Default for AssembledContext {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145#[derive(Debug, Clone, Default, Serialize, Deserialize)]
147pub enum AssemblyStrategy {
148 #[default]
150 Sequential,
151 DocumentGrouped,
153 Interleaved,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ContextAssemblerConfig {
160 pub max_tokens: usize,
162 pub strategy: AssemblyStrategy,
164 pub include_citations: bool,
166}
167
168impl Default for ContextAssemblerConfig {
169 fn default() -> Self {
170 Self { max_tokens: 4096, strategy: AssemblyStrategy::Sequential, include_citations: true }
171 }
172}
173
174#[derive(Debug, Clone)]
176pub struct ContextAssembler {
177 config: ContextAssemblerConfig,
178}
179
180impl ContextAssembler {
181 #[must_use]
183 pub fn new(config: ContextAssemblerConfig) -> Self {
184 Self { config }
185 }
186
187 #[must_use]
189 pub fn with_max_tokens(max_tokens: usize) -> Self {
190 Self::new(ContextAssemblerConfig { max_tokens, ..Default::default() })
191 }
192
193 #[must_use]
195 pub fn assemble(&self, results: &[RetrievalResult]) -> AssembledContext {
196 match self.config.strategy {
197 AssemblyStrategy::Sequential => self.assemble_sequential(results),
198 AssemblyStrategy::DocumentGrouped => self.assemble_grouped(results),
199 AssemblyStrategy::Interleaved => self.assemble_interleaved(results),
200 }
201 }
202
203 fn assemble_sequential(&self, results: &[RetrievalResult]) -> AssembledContext {
204 let mut context = AssembledContext::new();
205 let mut remaining_tokens = self.config.max_tokens;
206
207 for result in results {
208 let chunk_tokens = result.chunk.content.len() / 4; if chunk_tokens > remaining_tokens {
211 break;
213 }
214
215 let citation_id =
216 if self.config.include_citations { context.add_citation(result) } else { 0 };
217
218 context.add_chunk(result, citation_id);
219 remaining_tokens = remaining_tokens.saturating_sub(chunk_tokens);
220 }
221
222 context
223 }
224
225 fn assemble_grouped(&self, results: &[RetrievalResult]) -> AssembledContext {
226 let mut by_doc: HashMap<DocumentId, Vec<&RetrievalResult>> = HashMap::new();
228 for result in results {
229 by_doc.entry(result.chunk.document_id).or_default().push(result);
230 }
231
232 let mut context = AssembledContext::new();
234 let mut remaining_tokens = self.config.max_tokens;
235
236 for (_, doc_results) in by_doc {
237 for result in doc_results {
238 let chunk_tokens = result.chunk.content.len() / 4;
239
240 if chunk_tokens > remaining_tokens {
241 break;
242 }
243
244 let citation_id =
245 if self.config.include_citations { context.add_citation(result) } else { 0 };
246
247 context.add_chunk(result, citation_id);
248 remaining_tokens = remaining_tokens.saturating_sub(chunk_tokens);
249 }
250 }
251
252 context
253 }
254
255 fn assemble_interleaved(&self, results: &[RetrievalResult]) -> AssembledContext {
256 self.assemble_sequential(results)
258 }
259}
260
261impl Default for ContextAssembler {
262 fn default() -> Self {
263 Self::new(ContextAssemblerConfig::default())
264 }
265}
266
267#[derive(Debug, Clone)]
269pub struct RagPipelineConfig {
270 pub chunk_size: usize,
272 pub chunk_overlap: usize,
274 pub embedding_dimension: usize,
276 pub retrieval: HybridRetrieverConfig,
278 pub context: ContextAssemblerConfig,
280}
281
282impl Default for RagPipelineConfig {
283 fn default() -> Self {
284 Self {
285 chunk_size: 512,
286 chunk_overlap: 50,
287 embedding_dimension: DEFAULT_EMBEDDING_DIM,
288 retrieval: HybridRetrieverConfig::default(),
289 context: ContextAssemblerConfig::default(),
290 }
291 }
292}
293
294pub struct RagPipeline<E: Embedder, R: Reranker> {
296 chunker: Box<dyn Chunker>,
298 embedder: E,
300 retriever: HybridRetriever<E>,
302 reranker: R,
304 assembler: ContextAssembler,
306 document_count: usize,
308}
309
310impl<E: Embedder + Clone, R: Reranker> RagPipeline<E, R> {
311 pub fn index_document(&mut self, document: &Document) -> Result<Vec<Chunk>> {
313 let mut chunks = self.chunker.chunk(document)?;
315
316 self.embedder.embed_chunks(&mut chunks)?;
318
319 for chunk in &chunks {
321 self.retriever.index(chunk.clone())?;
322 }
323
324 self.document_count += 1;
325 Ok(chunks)
326 }
327
328 pub fn index_documents(&mut self, documents: &[Document]) -> Result<usize> {
330 let mut total_chunks = 0;
331 for doc in documents {
332 let chunks = self.index_document(doc)?;
333 total_chunks += chunks.len();
334 }
335 Ok(total_chunks)
336 }
337
338 #[must_use]
340 pub fn document_count(&self) -> usize {
341 self.document_count
342 }
343
344 #[must_use]
346 pub fn chunk_count(&self) -> usize {
347 self.retriever.len()
348 }
349
350 pub fn query(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
352 let mut results = self.retriever.retrieve(query, k * 2)?;
354
355 results = self.reranker.rerank(query, &results, k)?;
357
358 Ok(results)
359 }
360
361 pub fn query_with_context(
363 &self,
364 query: &str,
365 k: usize,
366 ) -> Result<(Vec<RetrievalResult>, AssembledContext)> {
367 let results = self.query(query, k)?;
368 let context = self.assembler.assemble(&results);
369 Ok((results, context))
370 }
371
372 #[must_use]
374 pub fn assembler(&self) -> &ContextAssembler {
375 &self.assembler
376 }
377
378 #[must_use]
380 pub fn assemble_context(&self, results: &[RetrievalResult]) -> AssembledContext {
381 self.assembler.assemble(results)
382 }
383
384 #[must_use]
386 pub fn chunker(&self) -> &dyn Chunker {
387 self.chunker.as_ref()
388 }
389
390 #[must_use]
392 pub fn embedder(&self) -> &E {
393 &self.embedder
394 }
395}
396
397pub struct RagPipelineBuilder<E: Embedder, R: Reranker> {
399 chunker: Option<Box<dyn Chunker>>,
400 embedder: Option<E>,
401 vector_store: Option<VectorStore>,
402 sparse_index: Option<BM25Index>,
403 reranker: Option<R>,
404 fusion: FusionStrategy,
405 assembler_config: ContextAssemblerConfig,
406}
407
408impl<E: Embedder + Clone, R: Reranker> RagPipelineBuilder<E, R> {
409 #[must_use]
411 pub fn new() -> Self {
412 Self {
413 chunker: None,
414 embedder: None,
415 vector_store: None,
416 sparse_index: None,
417 reranker: None,
418 fusion: FusionStrategy::default(),
419 assembler_config: ContextAssemblerConfig::default(),
420 }
421 }
422
423 #[must_use]
425 pub fn chunker(mut self, chunker: impl Chunker + 'static) -> Self {
426 self.chunker = Some(Box::new(chunker));
427 self
428 }
429
430 #[must_use]
432 pub fn embedder(mut self, embedder: E) -> Self {
433 self.embedder = Some(embedder);
434 self
435 }
436
437 #[must_use]
439 pub fn vector_store(mut self, store: VectorStore) -> Self {
440 self.vector_store = Some(store);
441 self
442 }
443
444 #[must_use]
446 pub fn sparse_index(mut self, index: BM25Index) -> Self {
447 self.sparse_index = Some(index);
448 self
449 }
450
451 #[must_use]
453 pub fn reranker(mut self, reranker: R) -> Self {
454 self.reranker = Some(reranker);
455 self
456 }
457
458 #[must_use]
460 pub fn fusion(mut self, fusion: FusionStrategy) -> Self {
461 self.fusion = fusion;
462 self
463 }
464
465 #[must_use]
467 pub fn max_context_tokens(mut self, max_tokens: usize) -> Self {
468 self.assembler_config.max_tokens = max_tokens;
469 self
470 }
471
472 pub fn build(self) -> Result<RagPipeline<E, R>> {
474 let embedder =
475 self.embedder.ok_or_else(|| Error::InvalidConfig("embedder required".to_string()))?;
476
477 let reranker =
478 self.reranker.ok_or_else(|| Error::InvalidConfig("reranker required".to_string()))?;
479
480 let chunker = self.chunker.unwrap_or_else(|| Box::new(RecursiveChunker::new(512, 50)));
481
482 let vector_store =
483 self.vector_store.unwrap_or_else(|| VectorStore::with_dimension(embedder.dimension()));
484
485 let sparse_index = self.sparse_index.unwrap_or_default();
486
487 let retrieval_config = HybridRetrieverConfig { fusion: self.fusion, ..Default::default() };
488
489 let retriever = HybridRetriever::new(vector_store, sparse_index, embedder.clone())
490 .with_config(retrieval_config);
491
492 let assembler = ContextAssembler::new(self.assembler_config);
493
494 Ok(RagPipeline { chunker, embedder, retriever, reranker, assembler, document_count: 0 })
495 }
496}
497
498impl<E: Embedder + Clone, R: Reranker> Default for RagPipelineBuilder<E, R> {
499 fn default() -> Self {
500 Self::new()
501 }
502}
503
504#[must_use]
506pub fn pipeline_builder() -> RagPipelineBuilder<MockEmbedder, NoOpReranker> {
507 RagPipelineBuilder::new()
508}
509
510#[cfg(test)]
511mod tests;