ruvector_graph/hybrid/
rag_integration.rs

//! RAG (Retrieval-Augmented Generation) integration
//!
//! Provides graph-based context retrieval and multi-hop reasoning for LLMs.

use crate::error::Result;
use crate::hybrid::semantic_search::{SemanticPath, SemanticSearch};
use crate::types::NodeId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for RAG engine
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Maximum context size (in tokens)
    pub max_context_tokens: usize,
    /// Number of top documents to retrieve
    pub top_k_docs: usize,
    /// Maximum reasoning depth (hops in graph)
    pub max_reasoning_depth: usize,
    /// Minimum relevance score
    pub min_relevance: f32,
    /// Enable multi-hop reasoning
    pub multi_hop_reasoning: bool,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            max_context_tokens: 4096,
            top_k_docs: 5,
            max_reasoning_depth: 3,
            min_relevance: 0.7,
            multi_hop_reasoning: true,
        }
    }
}

/// RAG engine for graph-based retrieval
pub struct RagEngine {
    /// Semantic search engine
    semantic_search: SemanticSearch,
    /// Configuration
    config: RagConfig,
}

impl RagEngine {
    /// Create a new RAG engine
    pub fn new(semantic_search: SemanticSearch, config: RagConfig) -> Self {
        Self {
            semantic_search,
            config,
        }
    }

    /// Retrieve relevant context for a query embedding
    pub fn retrieve_context(&self, query: &[f32]) -> Result<Context> {
        // Find the top-k most relevant documents
        let matches = self
            .semantic_search
            .find_similar_nodes(query, self.config.top_k_docs)?;

        let mut documents = Vec::new();
        for match_result in matches {
            if match_result.score >= self.config.min_relevance {
                documents.push(Document {
                    node_id: match_result.node_id.clone(),
                    // Placeholder content; a full implementation would load the
                    // node's stored text rather than formatting its id.
                    content: format!("Document {}", match_result.node_id),
                    metadata: HashMap::new(),
                    relevance_score: match_result.score,
                });
            }
        }

        let total_tokens = self.estimate_tokens(&documents);

        Ok(Context {
            documents,
            total_tokens,
            query_embedding: query.to_vec(),
        })
    }

    /// Build multi-hop reasoning paths
    pub fn build_reasoning_paths(
        &self,
        start_node: &NodeId,
        query: &[f32],
    ) -> Result<Vec<ReasoningPath>> {
        if !self.config.multi_hop_reasoning {
            return Ok(Vec::new());
        }

        // Find semantic paths through the graph
        let semantic_paths =
            self.semantic_search
                .find_semantic_paths(start_node, query, self.config.top_k_docs)?;

        // Convert semantic paths to reasoning paths
        let reasoning_paths = semantic_paths
            .into_iter()
            .map(|path| self.convert_to_reasoning_path(path))
            .collect();

        Ok(reasoning_paths)
    }

    /// Aggregate evidence from multiple sources
    pub fn aggregate_evidence(&self, paths: &[ReasoningPath]) -> Result<Vec<Evidence>> {
        let mut evidence_map: HashMap<NodeId, Evidence> = HashMap::new();

        for path in paths {
            for step in &path.steps {
                evidence_map
                    .entry(step.node_id.clone())
                    .and_modify(|e| {
                        e.support_count += 1;
                        e.confidence = e.confidence.max(step.confidence);
                    })
                    .or_insert_with(|| Evidence {
                        node_id: step.node_id.clone(),
                        content: step.content.clone(),
                        support_count: 1,
                        confidence: step.confidence,
                        sources: vec![step.node_id.clone()],
                    });
            }
        }

        let mut evidence: Vec<_> = evidence_map.into_values().collect();
        evidence.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(evidence)
    }

    /// Generate context-aware prompt
    pub fn generate_prompt(&self, query: &str, context: &Context) -> String {
        let mut prompt = String::new();

        prompt.push_str("Based on the following context, answer the question.\n\n");
        prompt.push_str("Context:\n");

        for (i, doc) in context.documents.iter().enumerate() {
            prompt.push_str(&format!(
                "{}. {} (relevance: {:.2})\n",
                i + 1,
                doc.content,
                doc.relevance_score
            ));
        }

        prompt.push_str("\nQuestion: ");
        prompt.push_str(query);
        prompt.push_str("\n\nAnswer:");

        prompt
    }

    /// Rerank results based on graph structure
    pub fn rerank_results(
        &self,
        initial_results: Vec<Document>,
        _query: &[f32],
    ) -> Result<Vec<Document>> {
        // Simple reranking by relevance score. A full implementation would also
        // consider:
        // - Graph centrality
        // - Cross-document connections
        // - Temporal relevance
        // - User preferences

        let mut results = initial_results;
        results.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(results)
    }

    /// Convert semantic path to reasoning path
    fn convert_to_reasoning_path(&self, semantic_path: SemanticPath) -> ReasoningPath {
        let steps = semantic_path
            .nodes
            .iter()
            .map(|node_id| ReasoningStep {
                node_id: node_id.clone(),
                content: format!("Step at node {}", node_id),
                relationship: "RELATED_TO".to_string(),
                confidence: semantic_path.semantic_score,
            })
            .collect();

        ReasoningPath {
            steps,
            total_confidence: semantic_path.combined_score,
            explanation: format!("Reasoning path with {} steps", semantic_path.nodes.len()),
        }
    }

    /// Estimate token count for documents
    fn estimate_tokens(&self, documents: &[Document]) -> usize {
        // Rough estimation: ~4 characters per token
        documents.iter().map(|doc| doc.content.len() / 4).sum()
    }
}

/// Retrieved context for generation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Retrieved documents
    pub documents: Vec<Document>,
    /// Total estimated tokens
    pub total_tokens: usize,
    /// Original query embedding
    pub query_embedding: Vec<f32>,
}

/// A retrieved document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Graph node the document was retrieved from
    pub node_id: NodeId,
    /// Document text
    pub content: String,
    /// Arbitrary key/value metadata
    pub metadata: HashMap<String, String>,
    /// Similarity score against the query
    pub relevance_score: f32,
}

/// A multi-hop reasoning path
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningPath {
    /// Steps in the reasoning chain
    pub steps: Vec<ReasoningStep>,
    /// Overall confidence in this path
    pub total_confidence: f32,
    /// Human-readable explanation
    pub explanation: String,
}

/// A single step in reasoning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningStep {
    /// Node visited at this step
    pub node_id: NodeId,
    /// Content associated with the node
    pub content: String,
    /// Relationship that led to this step
    pub relationship: String,
    /// Confidence in this step
    pub confidence: f32,
}

/// Aggregated evidence from multiple paths
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    /// Node the evidence comes from
    pub node_id: NodeId,
    /// Evidence content
    pub content: String,
    /// Number of reasoning steps referencing this node
    pub support_count: usize,
    /// Highest confidence observed for this node
    pub confidence: f32,
    /// Nodes that contributed this evidence
    pub sources: Vec<NodeId>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::semantic_search::SemanticSearchConfig;
    use crate::hybrid::vector_index::{EmbeddingConfig, HybridIndex};

    #[test]
    fn test_rag_engine_creation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let _rag = RagEngine::new(semantic_search, RagConfig::default());
    }
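
    // Added sketch of a test for the multi-hop gate in `build_reasoning_paths`:
    // with `multi_hop_reasoning` disabled the engine returns an empty set of
    // paths without consulting the semantic index. The node id used here is
    // illustrative.
    #[test]
    fn test_reasoning_paths_disabled() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default())?;
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let config = RagConfig {
            multi_hop_reasoning: false,
            ..RagConfig::default()
        };
        let rag = RagEngine::new(semantic_search, config);

        let start: NodeId = "doc1".to_string();
        let paths = rag.build_reasoning_paths(&start, &[1.0, 0.0, 0.0, 0.0])?;
        assert!(paths.is_empty());
        Ok(())
    }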

    #[test]
    fn test_context_retrieval() -> Result<()> {
        use crate::hybrid::vector_index::VectorIndexType;

        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        // Initialize the node index
        index.initialize_index(VectorIndexType::Node)?;

        // Add test embeddings so search returns results
        index.add_node_embedding("doc1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("doc2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;

        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let context = rag.retrieve_context(&query)?;

        assert_eq!(context.query_embedding, query);
        // Should find at least one document
        assert!(!context.documents.is_empty());
        Ok(())
    }
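
    // Added sketch of a test for `aggregate_evidence`: two hand-built reasoning
    // paths share the node "n1", so its evidence should carry a support count of
    // 2, keep the higher of the two step confidences, and sort ahead of "n2".
    // All ids and scores are illustrative.
    #[test]
    fn test_evidence_aggregation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let step = |node: &str, confidence: f32| ReasoningStep {
            node_id: node.to_string(),
            content: format!("Step at node {}", node),
            relationship: "RELATED_TO".to_string(),
            confidence,
        };
        let paths = vec![
            ReasoningPath {
                steps: vec![step("n1", 0.8), step("n2", 0.6)],
                total_confidence: 0.7,
                explanation: "path a".to_string(),
            },
            ReasoningPath {
                steps: vec![step("n1", 0.9)],
                total_confidence: 0.9,
                explanation: "path b".to_string(),
            },
        ];

        let evidence = rag.aggregate_evidence(&paths).unwrap();
        assert_eq!(evidence.len(), 2);
        assert_eq!(evidence[0].node_id, "n1");
        assert_eq!(evidence[0].support_count, 2);
        assert!((evidence[0].confidence - 0.9).abs() < 1e-6);
        assert_eq!(evidence[1].node_id, "n2");
    }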

    #[test]
    fn test_prompt_generation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let context = Context {
            documents: vec![Document {
                node_id: "doc1".to_string(),
                content: "Test content".to_string(),
                metadata: HashMap::new(),
                relevance_score: 0.9,
            }],
            total_tokens: 100,
            query_embedding: vec![1.0; 4],
        };

        let prompt = rag.generate_prompt("What is the answer?", &context);
        assert!(prompt.contains("Test content"));
        assert!(prompt.contains("What is the answer?"));
    }
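
    // Added sketch of a test for `rerank_results`: documents supplied out of
    // order should come back sorted by descending relevance score. Ids and
    // scores are illustrative.
    #[test]
    fn test_rerank_results_orders_by_score() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let doc = |id: &str, score: f32| Document {
            node_id: id.to_string(),
            content: format!("Document {}", id),
            metadata: HashMap::new(),
            relevance_score: score,
        };
        let initial = vec![doc("low", 0.2), doc("high", 0.9), doc("mid", 0.5)];

        let reranked = rag.rerank_results(initial, &[1.0, 0.0, 0.0, 0.0]).unwrap();
        let order: Vec<_> = reranked.iter().map(|d| d.node_id.as_str()).collect();
        assert_eq!(order, vec!["high", "mid", "low"]);
    }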
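
    // Added sketch of a test for the private `estimate_tokens` helper, which
    // applies the rough ~4-characters-per-token heuristic documented above.
    #[test]
    fn test_token_estimation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let documents = vec![Document {
            node_id: "doc1".to_string(),
            content: "x".repeat(40), // 40 characters -> ~10 tokens
            metadata: HashMap::new(),
            relevance_score: 1.0,
        }];

        assert_eq!(rag.estimate_tokens(&documents), 10);
    }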
}