Skip to main content

synapse_core/
server.rs

1use std::sync::Arc;
2use tonic::{Request, Response, Status};
3use dashmap::DashMap;
4
5pub mod proto {
6    tonic::include_proto!("semantic_engine");
7}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::store::SynapseStore;
13use crate::reasoner::{SynapseReasoner, ReasoningStrategy as InternalStrategy};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::ingest::IngestionEngine;
16use std::path::Path;
17
18use crate::auth::NamespaceAuth;
19use crate::audit::InferenceAudit;
20
21pub struct MySemanticEngine {
22    pub storage_path: String,
23    pub stores: DashMap<String, Arc<SynapseStore>>,
24    pub auth: Arc<NamespaceAuth>,
25    pub audit: Arc<InferenceAudit>,
26}
27
28impl MySemanticEngine {
29    pub fn new(storage_path: &str) -> Self {
30        let auth = Arc::new(NamespaceAuth::new());
31        auth.load_from_env();
32        
33        Self {
34            storage_path: storage_path.to_string(),
35            stores: DashMap::new(),
36            auth,
37            audit: Arc::new(InferenceAudit::new()),
38        }
39    }
40
41    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
42        if let Some(store) = self.stores.get(namespace) {
43            return Ok(store.clone());
44        }
45
46        let store = SynapseStore::open(namespace, &self.storage_path)
47            .map_err(|e| Status::internal(format!("Failed to open store for namespace '{}': {}", namespace, e)))?;
48        
49        let store_arc = Arc::new(store);
50        self.stores.insert(namespace.to_string(), store_arc.clone());
51        Ok(store_arc)
52    }
53}
54
55#[tonic::async_trait]
56impl SemanticEngine for MySemanticEngine {
57    async fn ingest_triples(
58        &self,
59        request: Request<IngestRequest>,
60    ) -> Result<Response<IngestResponse>, Status> {
61        // Auth check (Write permission)
62        let token = request.metadata().get("authorization").and_then(|t| t.to_str().ok()).map(|s| s.trim_start_matches("Bearer ").to_string());
63        let req = request.into_inner();
64        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
65        
66        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
67             return Err(Status::permission_denied(e));
68        }
69        
70        let store = self.get_store(namespace)?;
71
72        // Log provenance for audit
73        let timestamp = chrono::Utc::now().to_rfc3339();
74        let triple_count = req.triples.len();
75        let mut sources: Vec<String> = Vec::new();
76
77        let triples: Vec<(String, String, String)> = req
78            .triples
79            .into_iter()
80            .map(|t| {
81                // Capture provenance sources for logging
82                if let Some(ref prov) = t.provenance {
83                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
84                        sources.push(prov.source.clone());
85                    }
86                }
87                (t.subject, t.predicate, t.object)
88            })
89            .collect();
90
91        match store.ingest_triples(triples).await {
92            Ok((added, _)) => {
93                // Log ingestion for audit trail
94                eprintln!(
95                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
96                    sources
97                );
98                Ok(Response::new(IngestResponse {
99                    nodes_added: added,
100                    edges_added: added,
101                }))
102            }
103            Err(e) => Err(Status::internal(e.to_string())),
104        }
105    }
106
107    async fn ingest_file(
108        &self,
109        request: Request<IngestFileRequest>,
110    ) -> Result<Response<IngestResponse>, Status> {
111        let req = request.into_inner();
112        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
113        let store = self.get_store(namespace)?;
114        
115        let engine = IngestionEngine::new(store);
116        let path = Path::new(&req.file_path);
117
118        match engine.ingest_file(path, namespace).await {
119            Ok(count) => Ok(Response::new(IngestResponse {
120                nodes_added: count,
121                edges_added: count,
122            })),
123            Err(e) => Err(Status::internal(e.to_string())),
124        }
125    }
126
127    async fn get_neighbors(
128        &self,
129        request: Request<NodeRequest>,
130    ) -> Result<Response<NeighborResponse>, Status> {
131        let req = request.into_inner();
132        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
133        let store = self.get_store(namespace)?;
134        
135        let direction = if req.direction.is_empty() { "outgoing" } else { &req.direction };
136        let edge_filter = if req.edge_filter.is_empty() { None } else { Some(req.edge_filter.as_str()) };
137        let max_depth = if req.depth == 0 { 1 } else { req.depth as usize };
138        let limit_per_layer = if req.limit_per_layer == 0 { usize::MAX } else { req.limit_per_layer as usize };
139
140        let mut neighbors = Vec::new();
141        let mut visited = std::collections::HashSet::new();
142        let mut current_frontier = Vec::new();
143
144        // Start with the initial node
145        if let Some(start_uri) = store.get_uri(req.node_id) {
146            current_frontier.push(start_uri.clone());
147            visited.insert(start_uri);
148        }
149
150        // BFS traversal up to max_depth
151        for current_depth in 1..=max_depth {
152            let mut next_frontier = Vec::new();
153            let mut layer_count = 0;
154            let score = 1.0 / current_depth as f32;  // Path scoring: closer = higher
155
156            for uri in &current_frontier {
157                if layer_count >= limit_per_layer {
158                    break;
159                }
160
161                // Query outgoing edges (URI as subject)
162                if direction == "outgoing" || direction == "both" {
163                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
164                        for quad in store.store.quads_for_pattern(
165                            Some(subj.into()),
166                            None,
167                            None,
168                            None,
169                        ) {
170                            if layer_count >= limit_per_layer {
171                                break;
172                            }
173                            if let Ok(q) = quad {
174                                let pred = q.predicate.to_string();
175                                // Apply edge filter if specified
176                                if let Some(filter) = edge_filter {
177                                    if !pred.contains(filter) {
178                                        continue;
179                                    }
180                                }
181                                let obj_uri = q.object.to_string();
182                                if !visited.contains(&obj_uri) {
183                                    visited.insert(obj_uri.clone());
184                                    let obj_id = store.get_or_create_id(&obj_uri);
185                                    neighbors.push(Neighbor {
186                                        node_id: obj_id,
187                                        edge_type: pred,
188                                        uri: obj_uri.clone(),
189                                        direction: "outgoing".to_string(),
190                                        depth: current_depth as u32,
191                                        score,
192                                    });
193                                    next_frontier.push(obj_uri);
194                                    layer_count += 1;
195                                }
196                            }
197                        }
198                    }
199                }
200
201                // Query incoming edges (URI as object)
202                if direction == "incoming" || direction == "both" {
203                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
204                        for quad in store.store.quads_for_pattern(
205                            None,
206                            None,
207                            Some(obj.into()),
208                            None,
209                        ) {
210                            if layer_count >= limit_per_layer {
211                                break;
212                            }
213                            if let Ok(q) = quad {
214                                let pred = q.predicate.to_string();
215                                // Apply edge filter if specified
216                                if let Some(filter) = edge_filter {
217                                    if !pred.contains(filter) {
218                                        continue;
219                                    }
220                                }
221                                let subj_uri = q.subject.to_string();
222                                if !visited.contains(&subj_uri) {
223                                    visited.insert(subj_uri.clone());
224                                    let subj_id = store.get_or_create_id(&subj_uri);
225                                    neighbors.push(Neighbor {
226                                        node_id: subj_id,
227                                        edge_type: pred,
228                                        uri: subj_uri.clone(),
229                                        direction: "incoming".to_string(),
230                                        depth: current_depth as u32,
231                                        score,
232                                    });
233                                    next_frontier.push(subj_uri);
234                                    layer_count += 1;
235                                }
236                            }
237                        }
238                    }
239                }
240            }
241
242            current_frontier = next_frontier;
243            if current_frontier.is_empty() {
244                break;
245            }
246        }
247
248        // Sort by score (highest first)
249        neighbors.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
250
251        Ok(Response::new(NeighborResponse { neighbors }))
252    }
253
254    async fn search(
255        &self,
256        request: Request<SearchRequest>,
257    ) -> Result<Response<SearchResponse>, Status> {
258        let req = request.into_inner();
259        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
260        let store = self.get_store(namespace)?;
261
262        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
263            Ok(results) => {
264                let grpc_results = results
265                    .into_iter()
266                    .enumerate()
267                    .map(|(idx, (uri, score))| SearchResult {
268                        node_id: idx as u32,
269                        score,
270                        content: uri.clone(),
271                        uri,
272                    })
273                    .collect();
274                Ok(Response::new(SearchResponse { results: grpc_results }))
275            }
276            Err(e) => Err(Status::internal(e.to_string())),
277        }
278    }
279
280    async fn resolve_id(
281        &self,
282        request: Request<ResolveRequest>,
283    ) -> Result<Response<ResolveResponse>, Status> {
284        let req = request.into_inner();
285        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
286        let store = self.get_store(namespace)?;
287
288        // Look up the URI in our mapping
289        let uri_to_id = store.uri_to_id.read().unwrap();
290        if let Some(&node_id) = uri_to_id.get(&req.content) {
291            Ok(Response::new(ResolveResponse {
292                node_id,
293                found: true,
294            }))
295        } else {
296            Ok(Response::new(ResolveResponse {
297                node_id: 0,
298                found: false,
299            }))
300        }
301    }
302
303    async fn get_all_triples(
304        &self,
305        request: Request<EmptyRequest>,
306    ) -> Result<Response<TriplesResponse>, Status> {
307        let req = request.into_inner();
308        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
309        let store = self.get_store(namespace)?;
310        
311        let mut triples = Vec::new();
312
313        for quad in store.store.iter().map(|q| q.unwrap()) {
314            triples.push(Triple {
315                subject: quad.subject.to_string(),
316                predicate: quad.predicate.to_string(),
317                object: quad.object.to_string(),
318                provenance: Some(Provenance {
319                    source: "oxigraph".to_string(),
320                    timestamp: "".to_string(),
321                    method: "storage".to_string(),
322                }),
323                embedding: vec![],
324            });
325        }
326
327        Ok(Response::new(TriplesResponse { triples }))
328    }
329
330    async fn query_sparql(
331        &self,
332        request: Request<SparqlRequest>,
333    ) -> Result<Response<SparqlResponse>, Status> {
334        let req = request.into_inner();
335        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
336        let store = self.get_store(namespace)?;
337        
338        match store.query_sparql(&req.query) {
339            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
340            Err(e) => Err(Status::internal(e.to_string())),
341        }
342    }
343
344    async fn delete_namespace_data(
345        &self,
346        request: Request<EmptyRequest>,
347    ) -> Result<Response<DeleteResponse>, Status> {
348        let req = request.into_inner();
349        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
350        
351        // Remove from cache
352        self.stores.remove(namespace);
353        
354        // Delete directory
355        let path = Path::new(&self.storage_path).join(namespace);
356        if path.exists() {
357            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
358        }
359
360        Ok(Response::new(DeleteResponse {
361            success: true,
362            message: format!("Deleted namespace '{}'", namespace),
363        }))
364    }
365
366    async fn hybrid_search(
367        &self,
368        request: Request<HybridSearchRequest>,
369    ) -> Result<Response<SearchResponse>, Status> {
370        let req = request.into_inner();
371        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
372        let store = self.get_store(namespace)?;
373
374        let vector_k = req.vector_k as usize;
375        let graph_depth = req.graph_depth;
376
377        let results = match SearchMode::try_from(req.mode) {
378            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => {
379                store.hybrid_search(&req.query, vector_k, graph_depth).await
380                    .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?
381            }
382            _ => vec![],
383        };
384
385        let grpc_results = results
386            .into_iter()
387            .enumerate()
388            .map(|(idx, (uri, score))| SearchResult {
389                node_id: idx as u32,
390                score,
391                content: uri.clone(),
392                uri,
393            })
394            .collect();
395
396        Ok(Response::new(SearchResponse { results: grpc_results }))
397    }
398
399    async fn apply_reasoning(
400        &self,
401        request: Request<ReasoningRequest>,
402    ) -> Result<Response<ReasoningResponse>, Status> {
403        // Auth check (Reason permission)
404        let token = request.metadata().get("authorization").and_then(|t| t.to_str().ok()).map(|s| s.trim_start_matches("Bearer ").to_string());
405        let req = request.into_inner();
406        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
407        
408        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
409             return Err(Status::permission_denied(e));
410        }
411
412        let store = self.get_store(namespace)?;
413        
414        let strategy = match ReasoningStrategy::try_from(req.strategy) {
415            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
416            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
417            _ => InternalStrategy::None,
418        };
419        let strategy_name = format!("{:?}", strategy);
420
421        let reasoner = SynapseReasoner::new(strategy);
422        let start_triples = store.store.len().unwrap_or(0);
423        
424        let response = if req.materialize {
425            match reasoner.materialize(&store.store) {
426                Ok(count) => {
427                    Ok(Response::new(ReasoningResponse {
428                        success: true,
429                        triples_inferred: count as u32,
430                        message: format!("Materialized {} triples in namespace '{}'", count, namespace),
431                    }))
432                },
433                Err(e) => Err(Status::internal(e.to_string())),
434            }
435        } else {
436            match reasoner.apply(&store.store) {
437                Ok(triples) => {
438                    Ok(Response::new(ReasoningResponse {
439                        success: true,
440                        triples_inferred: triples.len() as u32,
441                        message: format!("Found {} inferred triples in namespace '{}'", triples.len(), namespace),
442                    }))
443                },
444                Err(e) => Err(Status::internal(e.to_string())),
445            }
446        };
447
448        // Audit Log
449        if let Ok(ref res) = response {
450            let inferred = res.get_ref().triples_inferred as usize;
451            self.audit.log(
452                namespace,
453                &strategy_name,
454                start_triples,
455                inferred,
456                0, // Duplicates skipped not easily tracked here without changing reasoner return signature
457                vec![] // Sample inferences
458            );
459        }
460
461        response
462    }
463}
464
465pub async fn run_mcp_stdio(engine: Arc<MySemanticEngine>) -> Result<(), Box<dyn std::error::Error>> {
466    let server = crate::mcp_stdio::McpStdioServer::new(engine);
467    server.run().await
468}