//! Retrieval-augmented generation (RAG) over the hybrid semantic index.

use crate::error::Result;
use crate::hybrid::semantic_search::{SemanticPath, SemanticSearch};
use crate::types::NodeId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for the retrieval-augmented generation (RAG) engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Maximum number of tokens allowed in the assembled context.
    pub max_context_tokens: usize,
    /// Number of top-scoring documents to retrieve per query.
    pub top_k_docs: usize,
    /// Maximum depth explored when building reasoning paths.
    pub max_reasoning_depth: usize,
    /// Minimum relevance score a document needs to be included.
    pub min_relevance: f32,
    /// Whether to build multi-hop reasoning paths.
    pub multi_hop_reasoning: bool,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            max_context_tokens: 4096,
            top_k_docs: 5,
            max_reasoning_depth: 3,
            min_relevance: 0.7,
            multi_hop_reasoning: true,
        }
    }
}

/// Graph-aware RAG engine that retrieves context and builds reasoning paths.
pub struct RagEngine {
    /// Semantic search layer over the hybrid index.
    semantic_search: SemanticSearch,
    /// Engine configuration.
    config: RagConfig,
}

impl RagEngine {
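    /// Creates a new RAG engine from a semantic search layer and a configuration.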
    pub fn new(semantic_search: SemanticSearch, config: RagConfig) -> Self {
        Self {
            semantic_search,
            config,
        }
    }

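    /// Retrieves up to `top_k_docs` documents for a query embedding, keeping
    /// only matches at or above `min_relevance` and stopping once the
    /// estimated `max_context_tokens` budget would be exceeded.
    ///
    /// Illustrative sketch (marked `ignore`, not compiled as a doc-test;
    /// assumes an engine whose index is populated as in the tests below):
    ///
    /// ```ignore
    /// let context = rag.retrieve_context(&query_embedding)?;
    /// for doc in &context.documents {
    ///     println!("{} ({:.2})", doc.node_id, doc.relevance_score);
    /// }
    /// ```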
    pub fn retrieve_context(&self, query: &[f32]) -> Result<Context> {
        let matches = self
            .semantic_search
            .find_similar_nodes(query, self.config.top_k_docs)?;

        let mut documents = Vec::new();
        let mut total_tokens = 0;
        for match_result in matches {
            if match_result.score < self.config.min_relevance {
                continue;
            }
            let document = Document {
                node_id: match_result.node_id.clone(),
                content: format!("Document {}", match_result.node_id),
                metadata: HashMap::new(),
                relevance_score: match_result.score,
            };
            // Enforce the context budget using the same heuristic as
            // `estimate_tokens`; matches arrive sorted by score, so this
            // keeps the most relevant documents.
            let doc_tokens = self.estimate_tokens(std::slice::from_ref(&document));
            if total_tokens + doc_tokens > self.config.max_context_tokens {
                break;
            }
            total_tokens += doc_tokens;
            documents.push(document);
        }

        Ok(Context {
            documents,
            total_tokens,
            query_embedding: query.to_vec(),
        })
    }

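    /// Builds multi-hop reasoning paths from `start_node` toward the query.
    ///
    /// Returns an empty list when `multi_hop_reasoning` is disabled.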
    pub fn build_reasoning_paths(
        &self,
        start_node: &NodeId,
        query: &[f32],
    ) -> Result<Vec<ReasoningPath>> {
        if !self.config.multi_hop_reasoning {
            return Ok(Vec::new());
        }

        let semantic_paths =
            self.semantic_search
                .find_semantic_paths(start_node, query, self.config.top_k_docs)?;

        let reasoning_paths = semantic_paths
            .into_iter()
            .map(|path| self.convert_to_reasoning_path(path))
            .collect();

        Ok(reasoning_paths)
    }

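    /// Aggregates evidence across reasoning paths: a node seen on several
    /// paths gains support and keeps its highest step confidence, and the
    /// result is sorted by confidence, descending.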
    pub fn aggregate_evidence(&self, paths: &[ReasoningPath]) -> Result<Vec<Evidence>> {
        let mut evidence_map: HashMap<NodeId, Evidence> = HashMap::new();

        for path in paths {
            for step in &path.steps {
                evidence_map
                    .entry(step.node_id.clone())
                    .and_modify(|e| {
                        e.support_count += 1;
                        e.confidence = e.confidence.max(step.confidence);
                    })
                    .or_insert_with(|| Evidence {
                        node_id: step.node_id.clone(),
                        content: step.content.clone(),
                        support_count: 1,
                        confidence: step.confidence,
                        sources: vec![step.node_id.clone()],
                    });
            }
        }

        let mut evidence: Vec<_> = evidence_map.into_values().collect();
        evidence.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(evidence)
    }

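    /// Renders a plain-text prompt: a header, the numbered context documents
    /// with their relevance scores, then the question, e.g.:
    ///
    /// ```text
    /// Based on the following context, answer the question.
    ///
    /// Context:
    /// 1. Document doc1 (relevance: 0.95)
    ///
    /// Question: ...
    ///
    /// Answer:
    /// ```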
    pub fn generate_prompt(&self, query: &str, context: &Context) -> String {
        let mut prompt = String::new();

        prompt.push_str("Based on the following context, answer the question.\n\n");
        prompt.push_str("Context:\n");

        for (i, doc) in context.documents.iter().enumerate() {
            prompt.push_str(&format!(
                "{}. {} (relevance: {:.2})\n",
                i + 1,
                doc.content,
                doc.relevance_score
            ));
        }

        prompt.push_str("\nQuestion: ");
        prompt.push_str(query);
        prompt.push_str("\n\nAnswer:");

        prompt
    }

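    /// Reranks retrieved documents. Currently a placeholder that sorts by the
    /// stored relevance score, descending; a full implementation would
    /// re-score each document against the query embedding.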
    pub fn rerank_results(
        &self,
        initial_results: Vec<Document>,
        _query: &[f32],
    ) -> Result<Vec<Document>> {
        let mut results = initial_results;
        results.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(results)
    }

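    /// Converts a semantic path into a reasoning path, using the path's
    /// semantic score as the per-step confidence and its combined score as
    /// the total confidence.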
    fn convert_to_reasoning_path(&self, semantic_path: SemanticPath) -> ReasoningPath {
        let steps = semantic_path
            .nodes
            .iter()
            .map(|node_id| ReasoningStep {
                node_id: node_id.clone(),
                content: format!("Step at node {}", node_id),
                relationship: "RELATED_TO".to_string(),
                confidence: semantic_path.semantic_score,
            })
            .collect();

        ReasoningPath {
            steps,
            total_confidence: semantic_path.combined_score,
            explanation: format!("Reasoning path with {} steps", semantic_path.nodes.len()),
        }
    }

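    /// Estimates the token count of a document set using a rough heuristic
    /// of about four characters per token.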
    fn estimate_tokens(&self, documents: &[Document]) -> usize {
        documents.iter().map(|doc| doc.content.len() / 4).sum()
    }
}

/// Context assembled for a RAG query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Context {
    /// Documents retrieved for the query, ordered by relevance.
    pub documents: Vec<Document>,
    /// Estimated token count of the included documents.
    pub total_tokens: usize,
    /// The query embedding used for retrieval.
    pub query_embedding: Vec<f32>,
}

/// A retrieved document with its provenance and relevance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Graph node this document came from.
    pub node_id: NodeId,
    /// Document text.
    pub content: String,
    /// Arbitrary string metadata.
    pub metadata: HashMap<String, String>,
    /// Similarity score against the query; higher is more relevant.
    pub relevance_score: f32,
}

/// A multi-hop reasoning path through the graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningPath {
    /// Ordered steps along the path.
    pub steps: Vec<ReasoningStep>,
    /// Overall confidence in the path.
    pub total_confidence: f32,
    /// Human-readable description of the path.
    pub explanation: String,
}

/// A single hop in a reasoning path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningStep {
    /// Node visited at this step.
    pub node_id: NodeId,
    /// Content associated with the node.
    pub content: String,
    /// Relationship that led to this step.
    pub relationship: String,
    /// Confidence in this step.
    pub confidence: f32,
}

/// Evidence for an answer, aggregated across reasoning paths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    /// Node the evidence comes from.
    pub node_id: NodeId,
    /// Evidence content.
    pub content: String,
    /// Number of reasoning steps that referenced this node.
    pub support_count: usize,
    /// Highest confidence observed for this node.
    pub confidence: f32,
    /// Nodes that contributed this evidence.
    pub sources: Vec<NodeId>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::semantic_search::SemanticSearchConfig;
    use crate::hybrid::vector_index::{EmbeddingConfig, HybridIndex};

    #[test]
    fn test_rag_engine_creation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let _rag = RagEngine::new(semantic_search, RagConfig::default());
    }

    #[test]
    fn test_context_retrieval() -> Result<()> {
        use crate::hybrid::vector_index::VectorIndexType;

        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        // Seed two near-parallel embeddings so both clear `min_relevance`.
        index.add_node_embedding("doc1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("doc2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;

        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let context = rag.retrieve_context(&query)?;

        assert_eq!(context.query_embedding, query);
        assert!(!context.documents.is_empty());
        Ok(())
    }

    #[test]
    fn test_prompt_generation() {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let context = Context {
            documents: vec![Document {
                node_id: "doc1".to_string(),
                content: "Test content".to_string(),
                metadata: HashMap::new(),
                relevance_score: 0.9,
            }],
            total_tokens: 100,
            query_embedding: vec![1.0; 4],
        };

        let prompt = rag.generate_prompt("What is the answer?", &context);
        assert!(prompt.contains("Test content"));
        assert!(prompt.contains("What is the answer?"));
    }
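
    // Added illustrative coverage for the in-memory aggregation and reranking
    // paths. These use only types and constructors already exercised above,
    // and assume `NodeId` is the `String` alias used throughout this module.
    #[test]
    fn test_evidence_aggregation() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let step = |id: &str, confidence: f32| ReasoningStep {
            node_id: id.to_string(),
            content: format!("Step at node {}", id),
            relationship: "RELATED_TO".to_string(),
            confidence,
        };
        let paths = vec![
            ReasoningPath {
                steps: vec![step("a", 0.9), step("b", 0.6)],
                total_confidence: 0.75,
                explanation: "path 1".to_string(),
            },
            ReasoningPath {
                steps: vec![step("b", 0.8)],
                total_confidence: 0.8,
                explanation: "path 2".to_string(),
            },
        ];

        let evidence = rag.aggregate_evidence(&paths)?;
        // "b" appears on both paths: two supports, max confidence kept.
        let b = evidence.iter().find(|e| e.node_id == "b").unwrap();
        assert_eq!(b.support_count, 2);
        assert!((b.confidence - 0.8).abs() < f32::EPSILON);
        // Results are sorted by confidence, descending.
        assert!(evidence.windows(2).all(|w| w[0].confidence >= w[1].confidence));
        Ok(())
    }

    #[test]
    fn test_rerank_results() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default()).unwrap();
        let semantic_search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let rag = RagEngine::new(semantic_search, RagConfig::default());

        let doc = |id: &str, score: f32| Document {
            node_id: id.to_string(),
            content: format!("Document {}", id),
            metadata: HashMap::new(),
            relevance_score: score,
        };
        let reranked =
            rag.rerank_results(vec![doc("low", 0.2), doc("high", 0.9)], &[1.0; 4])?;
        assert_eq!(reranked[0].node_id, "high");
        assert_eq!(reranked[1].node_id, "low");
        Ok(())
    }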
}