Skip to main content

trueno_rag/pipeline/
mod.rs

1//! RAG Pipeline implementation with context assembly
2
3use crate::{
4    chunk::{Chunk, Chunker, RecursiveChunker},
5    embed::{Embedder, MockEmbedder},
6    fusion::FusionStrategy,
7    index::{BM25Index, VectorStore},
8    rerank::{NoOpReranker, Reranker},
9    retrieve::{HybridRetriever, HybridRetrieverConfig, RetrievalResult},
10    Document, DocumentId, Error, Result,
11};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// Embedding width used when no explicit dimension is configured.
///
/// 384 matches compact sentence-embedding models such as
/// all-MiniLM-L6-v2 and BGE-small-en-v1.5.
const DEFAULT_EMBEDDING_DIM: usize = 384;
17
18/// Citation for a retrieved chunk
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Citation {
21    /// Citation ID (1-indexed for display)
22    pub id: usize,
23    /// Source document ID
24    pub document_id: DocumentId,
25    /// Source chunk ID
26    pub chunk_id: crate::ChunkId,
27    /// Document title (if available)
28    pub title: Option<String>,
29    /// Source URL (if available)
30    pub url: Option<String>,
31    /// Page number (if available)
32    pub page: Option<usize>,
33}
34
35/// A chunk in the assembled context
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ContextChunk {
38    /// The chunk content
39    pub content: String,
40    /// Citation ID
41    pub citation_id: usize,
42    /// Retrieval score
43    pub retrieval_score: f32,
44    /// Rerank score (if available)
45    pub rerank_score: Option<f32>,
46}
47
48/// Assembled context from retrieval results
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AssembledContext {
51    /// Ordered chunks in context
52    pub chunks: Vec<ContextChunk>,
53    /// Total token count (estimated)
54    pub total_tokens: usize,
55    /// Source citations
56    pub citations: Vec<Citation>,
57}
58
59impl AssembledContext {
60    /// Create a new empty context
61    #[must_use]
62    pub fn new() -> Self {
63        Self { chunks: Vec::new(), total_tokens: 0, citations: Vec::new() }
64    }
65
66    /// Add a chunk to the context
67    pub fn add_chunk(&mut self, result: &RetrievalResult, citation_id: usize) {
68        let chunk = ContextChunk {
69            content: result.chunk.content.clone(),
70            citation_id,
71            retrieval_score: result.best_score(),
72            rerank_score: result.rerank_score,
73        };
74
75        // Estimate tokens (rough: ~4 chars per token for English)
76        self.total_tokens += result.chunk.content.len() / 4;
77        self.chunks.push(chunk);
78    }
79
80    /// Add a citation
81    pub fn add_citation(&mut self, result: &RetrievalResult) -> usize {
82        let id = self.citations.len() + 1;
83
84        let citation = Citation {
85            id,
86            document_id: result.chunk.document_id,
87            chunk_id: result.chunk.id,
88            title: result.chunk.metadata.title.clone(),
89            url: None, // Would come from document metadata
90            page: result.chunk.metadata.page,
91        };
92
93        self.citations.push(citation);
94        id
95    }
96
97    /// Format context with inline citations
98    #[must_use]
99    pub fn format_with_citations(&self) -> String {
100        self.chunks
101            .iter()
102            .map(|c| format!("{} [{}]", c.content, c.citation_id))
103            .collect::<Vec<_>>()
104            .join("\n\n")
105    }
106
107    /// Format context without citations
108    #[must_use]
109    pub fn format_plain(&self) -> String {
110        self.chunks.iter().map(|c| c.content.as_str()).collect::<Vec<_>>().join("\n\n")
111    }
112
113    /// Generate citation list
114    #[must_use]
115    pub fn citation_list(&self) -> String {
116        self.citations
117            .iter()
118            .map(|c| {
119                let title = c.title.as_deref().unwrap_or("Untitled");
120                format!("[{}] {}", c.id, title)
121            })
122            .collect::<Vec<_>>()
123            .join("\n")
124    }
125
126    /// Get the number of chunks
127    #[must_use]
128    pub fn len(&self) -> usize {
129        self.chunks.len()
130    }
131
132    /// Check if the context is empty
133    #[must_use]
134    pub fn is_empty(&self) -> bool {
135        self.chunks.is_empty()
136    }
137}
138
139impl Default for AssembledContext {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145/// Strategy for assembling context from retrieval results
146#[derive(Debug, Clone, Default, Serialize, Deserialize)]
147pub enum AssemblyStrategy {
148    /// Simple concatenation in rank order
149    #[default]
150    Sequential,
151    /// Group by document, then by rank
152    DocumentGrouped,
153    /// Interleave chunks for diversity
154    Interleaved,
155}
156
157/// Context assembler configuration
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct ContextAssemblerConfig {
160    /// Maximum context length in tokens (estimated)
161    pub max_tokens: usize,
162    /// Assembly strategy
163    pub strategy: AssemblyStrategy,
164    /// Include citations
165    pub include_citations: bool,
166}
167
168impl Default for ContextAssemblerConfig {
169    fn default() -> Self {
170        Self { max_tokens: 4096, strategy: AssemblyStrategy::Sequential, include_citations: true }
171    }
172}
173
174/// Assembles retrieved chunks into a coherent context
175#[derive(Debug, Clone)]
176pub struct ContextAssembler {
177    config: ContextAssemblerConfig,
178}
179
180impl ContextAssembler {
181    /// Create a new context assembler
182    #[must_use]
183    pub fn new(config: ContextAssemblerConfig) -> Self {
184        Self { config }
185    }
186
187    /// Create with default configuration
188    #[must_use]
189    pub fn with_max_tokens(max_tokens: usize) -> Self {
190        Self::new(ContextAssemblerConfig { max_tokens, ..Default::default() })
191    }
192
193    /// Assemble context from retrieval results
194    #[must_use]
195    pub fn assemble(&self, results: &[RetrievalResult]) -> AssembledContext {
196        match self.config.strategy {
197            AssemblyStrategy::Sequential => self.assemble_sequential(results),
198            AssemblyStrategy::DocumentGrouped => self.assemble_grouped(results),
199            AssemblyStrategy::Interleaved => self.assemble_interleaved(results),
200        }
201    }
202
203    fn assemble_sequential(&self, results: &[RetrievalResult]) -> AssembledContext {
204        let mut context = AssembledContext::new();
205        let mut remaining_tokens = self.config.max_tokens;
206
207        for result in results {
208            let chunk_tokens = result.chunk.content.len() / 4; // Rough estimate
209
210            if chunk_tokens > remaining_tokens {
211                // Could truncate, but for now we just stop
212                break;
213            }
214
215            let citation_id =
216                if self.config.include_citations { context.add_citation(result) } else { 0 };
217
218            context.add_chunk(result, citation_id);
219            remaining_tokens = remaining_tokens.saturating_sub(chunk_tokens);
220        }
221
222        context
223    }
224
225    fn assemble_grouped(&self, results: &[RetrievalResult]) -> AssembledContext {
226        // Group by document
227        let mut by_doc: HashMap<DocumentId, Vec<&RetrievalResult>> = HashMap::new();
228        for result in results {
229            by_doc.entry(result.chunk.document_id).or_default().push(result);
230        }
231
232        // Flatten while respecting order within documents
233        let mut context = AssembledContext::new();
234        let mut remaining_tokens = self.config.max_tokens;
235
236        for (_, doc_results) in by_doc {
237            for result in doc_results {
238                let chunk_tokens = result.chunk.content.len() / 4;
239
240                if chunk_tokens > remaining_tokens {
241                    break;
242                }
243
244                let citation_id =
245                    if self.config.include_citations { context.add_citation(result) } else { 0 };
246
247                context.add_chunk(result, citation_id);
248                remaining_tokens = remaining_tokens.saturating_sub(chunk_tokens);
249            }
250        }
251
252        context
253    }
254
255    fn assemble_interleaved(&self, results: &[RetrievalResult]) -> AssembledContext {
256        // For now, same as sequential but could implement round-robin from different docs
257        self.assemble_sequential(results)
258    }
259}
260
261impl Default for ContextAssembler {
262    fn default() -> Self {
263        Self::new(ContextAssemblerConfig::default())
264    }
265}
266
267/// RAG Pipeline configuration
268#[derive(Debug, Clone)]
269pub struct RagPipelineConfig {
270    /// Chunking chunk size
271    pub chunk_size: usize,
272    /// Chunking overlap
273    pub chunk_overlap: usize,
274    /// Embedding dimension
275    pub embedding_dimension: usize,
276    /// Retrieval config
277    pub retrieval: HybridRetrieverConfig,
278    /// Context assembly config
279    pub context: ContextAssemblerConfig,
280}
281
282impl Default for RagPipelineConfig {
283    fn default() -> Self {
284        Self {
285            chunk_size: 512,
286            chunk_overlap: 50,
287            embedding_dimension: DEFAULT_EMBEDDING_DIM,
288            retrieval: HybridRetrieverConfig::default(),
289            context: ContextAssemblerConfig::default(),
290        }
291    }
292}
293
294/// Complete RAG pipeline
295pub struct RagPipeline<E: Embedder, R: Reranker> {
296    /// Document chunker
297    chunker: Box<dyn Chunker>,
298    /// Embedder
299    embedder: E,
300    /// Hybrid retriever
301    retriever: HybridRetriever<E>,
302    /// Reranker
303    reranker: R,
304    /// Context assembler
305    assembler: ContextAssembler,
306    /// Indexed document count
307    document_count: usize,
308}
309
310impl<E: Embedder + Clone, R: Reranker> RagPipeline<E, R> {
311    /// Index a single document
312    pub fn index_document(&mut self, document: &Document) -> Result<Vec<Chunk>> {
313        // Chunk the document
314        let mut chunks = self.chunker.chunk(document)?;
315
316        // Embed the chunks
317        self.embedder.embed_chunks(&mut chunks)?;
318
319        // Add to retriever (both dense and sparse indices)
320        for chunk in &chunks {
321            self.retriever.index(chunk.clone())?;
322        }
323
324        self.document_count += 1;
325        Ok(chunks)
326    }
327
328    /// Index multiple documents
329    pub fn index_documents(&mut self, documents: &[Document]) -> Result<usize> {
330        let mut total_chunks = 0;
331        for doc in documents {
332            let chunks = self.index_document(doc)?;
333            total_chunks += chunks.len();
334        }
335        Ok(total_chunks)
336    }
337
338    /// Get the number of indexed documents
339    #[must_use]
340    pub fn document_count(&self) -> usize {
341        self.document_count
342    }
343
344    /// Get the number of indexed chunks
345    #[must_use]
346    pub fn chunk_count(&self) -> usize {
347        self.retriever.len()
348    }
349
350    /// Query the pipeline
351    pub fn query(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
352        // Retrieve
353        let mut results = self.retriever.retrieve(query, k * 2)?;
354
355        // Rerank
356        results = self.reranker.rerank(query, &results, k)?;
357
358        Ok(results)
359    }
360
361    /// Query and assemble context
362    pub fn query_with_context(
363        &self,
364        query: &str,
365        k: usize,
366    ) -> Result<(Vec<RetrievalResult>, AssembledContext)> {
367        let results = self.query(query, k)?;
368        let context = self.assembler.assemble(&results);
369        Ok((results, context))
370    }
371
372    /// Get the context assembler
373    #[must_use]
374    pub fn assembler(&self) -> &ContextAssembler {
375        &self.assembler
376    }
377
378    /// Assemble context from results
379    #[must_use]
380    pub fn assemble_context(&self, results: &[RetrievalResult]) -> AssembledContext {
381        self.assembler.assemble(results)
382    }
383
384    /// Get the chunker
385    #[must_use]
386    pub fn chunker(&self) -> &dyn Chunker {
387        self.chunker.as_ref()
388    }
389
390    /// Get the embedder
391    #[must_use]
392    pub fn embedder(&self) -> &E {
393        &self.embedder
394    }
395}
396
397/// Builder for RAG pipeline
398pub struct RagPipelineBuilder<E: Embedder, R: Reranker> {
399    chunker: Option<Box<dyn Chunker>>,
400    embedder: Option<E>,
401    vector_store: Option<VectorStore>,
402    sparse_index: Option<BM25Index>,
403    reranker: Option<R>,
404    fusion: FusionStrategy,
405    assembler_config: ContextAssemblerConfig,
406}
407
408impl<E: Embedder + Clone, R: Reranker> RagPipelineBuilder<E, R> {
409    /// Create a new pipeline builder
410    #[must_use]
411    pub fn new() -> Self {
412        Self {
413            chunker: None,
414            embedder: None,
415            vector_store: None,
416            sparse_index: None,
417            reranker: None,
418            fusion: FusionStrategy::default(),
419            assembler_config: ContextAssemblerConfig::default(),
420        }
421    }
422
423    /// Set the chunker
424    #[must_use]
425    pub fn chunker(mut self, chunker: impl Chunker + 'static) -> Self {
426        self.chunker = Some(Box::new(chunker));
427        self
428    }
429
430    /// Set the embedder
431    #[must_use]
432    pub fn embedder(mut self, embedder: E) -> Self {
433        self.embedder = Some(embedder);
434        self
435    }
436
437    /// Set the vector store
438    #[must_use]
439    pub fn vector_store(mut self, store: VectorStore) -> Self {
440        self.vector_store = Some(store);
441        self
442    }
443
444    /// Set the sparse index
445    #[must_use]
446    pub fn sparse_index(mut self, index: BM25Index) -> Self {
447        self.sparse_index = Some(index);
448        self
449    }
450
451    /// Set the reranker
452    #[must_use]
453    pub fn reranker(mut self, reranker: R) -> Self {
454        self.reranker = Some(reranker);
455        self
456    }
457
458    /// Set the fusion strategy
459    #[must_use]
460    pub fn fusion(mut self, fusion: FusionStrategy) -> Self {
461        self.fusion = fusion;
462        self
463    }
464
465    /// Set max context tokens
466    #[must_use]
467    pub fn max_context_tokens(mut self, max_tokens: usize) -> Self {
468        self.assembler_config.max_tokens = max_tokens;
469        self
470    }
471
472    /// Build the pipeline
473    pub fn build(self) -> Result<RagPipeline<E, R>> {
474        let embedder =
475            self.embedder.ok_or_else(|| Error::InvalidConfig("embedder required".to_string()))?;
476
477        let reranker =
478            self.reranker.ok_or_else(|| Error::InvalidConfig("reranker required".to_string()))?;
479
480        let chunker = self.chunker.unwrap_or_else(|| Box::new(RecursiveChunker::new(512, 50)));
481
482        let vector_store =
483            self.vector_store.unwrap_or_else(|| VectorStore::with_dimension(embedder.dimension()));
484
485        let sparse_index = self.sparse_index.unwrap_or_default();
486
487        let retrieval_config = HybridRetrieverConfig { fusion: self.fusion, ..Default::default() };
488
489        let retriever = HybridRetriever::new(vector_store, sparse_index, embedder.clone())
490            .with_config(retrieval_config);
491
492        let assembler = ContextAssembler::new(self.assembler_config);
493
494        Ok(RagPipeline { chunker, embedder, retriever, reranker, assembler, document_count: 0 })
495    }
496}
497
498impl<E: Embedder + Clone, R: Reranker> Default for RagPipelineBuilder<E, R> {
499    fn default() -> Self {
500        Self::new()
501    }
502}
503
504/// Simplified pipeline builder with defaults
505#[must_use]
506pub fn pipeline_builder() -> RagPipelineBuilder<MockEmbedder, NoOpReranker> {
507    RagPipelineBuilder::new()
508}
509
510#[cfg(test)]
511mod tests;