Skip to main content

synapse_core/
server.rs

1use dashmap::DashMap;
2use std::sync::Arc;
3use tonic::{Request, Response, Status};
4
5pub mod proto {
6    tonic::include_proto!("semantic_engine");
7}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::ingest::IngestionEngine;
13use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::store::SynapseStore;
16use std::path::Path;
17
18use crate::audit::InferenceAudit;
19use crate::auth::NamespaceAuth;
20
21#[derive(Clone)]
22pub struct AuthToken(pub String);
23
24pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
25    if let Some(token) = req.metadata().get("authorization")
26        .and_then(|t| t.to_str().ok())
27        .map(|s| s.trim_start_matches("Bearer ").to_string()) {
28            req.extensions_mut().insert(AuthToken(token));
29    }
30    Ok(req)
31}
32
33fn get_token<T>(req: &Request<T>) -> Option<String> {
34    if let Some(token) = req.extensions().get::<AuthToken>() {
35        return Some(token.0.clone());
36    }
37    req.metadata().get("authorization")
38       .and_then(|t| t.to_str().ok())
39       .map(|s| s.trim_start_matches("Bearer ").to_string())
40}
41
42pub struct MySemanticEngine {
43    pub storage_path: String,
44    pub stores: DashMap<String, Arc<SynapseStore>>,
45    pub auth: Arc<NamespaceAuth>,
46    pub audit: Arc<InferenceAudit>,
47}
48
49impl MySemanticEngine {
50    pub fn new(storage_path: &str) -> Self {
51        let auth = Arc::new(NamespaceAuth::new());
52        auth.load_from_env();
53
54        Self {
55            storage_path: storage_path.to_string(),
56            stores: DashMap::new(),
57            auth,
58            audit: Arc::new(InferenceAudit::new()),
59        }
60    }
61
62    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
63        if let Some(store) = self.stores.get(namespace) {
64            return Ok(store.clone());
65        }
66
67        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
68            Status::internal(format!(
69                "Failed to open store for namespace '{}': {}",
70                namespace, e
71            ))
72        })?;
73
74        let store_arc = Arc::new(store);
75        self.stores.insert(namespace.to_string(), store_arc.clone());
76        Ok(store_arc)
77    }
78}
79
80#[tonic::async_trait]
81impl SemanticEngine for MySemanticEngine {
82    async fn ingest_triples(
83        &self,
84        request: Request<IngestRequest>,
85    ) -> Result<Response<IngestResponse>, Status> {
86        // Auth check (Write permission)
87        let token = get_token(&request);
88        let req = request.into_inner();
89        let namespace = if req.namespace.is_empty() {
90            "default"
91        } else {
92            &req.namespace
93        };
94
95        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
96            return Err(Status::permission_denied(e));
97        }
98
99        let store = self.get_store(namespace)?;
100
101        // Log provenance for audit
102        let timestamp = chrono::Utc::now().to_rfc3339();
103        let triple_count = req.triples.len();
104        let mut sources: Vec<String> = Vec::new();
105
106        let triples: Vec<(String, String, String)> = req
107            .triples
108            .into_iter()
109            .map(|t| {
110                // Capture provenance sources for logging
111                if let Some(ref prov) = t.provenance {
112                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
113                        sources.push(prov.source.clone());
114                    }
115                }
116                (t.subject, t.predicate, t.object)
117            })
118            .collect();
119
120        match store.ingest_triples(triples).await {
121            Ok((added, _)) => {
122                // Log ingestion for audit trail
123                eprintln!(
124                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
125                    sources
126                );
127                Ok(Response::new(IngestResponse {
128                    nodes_added: added,
129                    edges_added: added,
130                }))
131            }
132            Err(e) => Err(Status::internal(e.to_string())),
133        }
134    }
135
136    async fn ingest_file(
137        &self,
138        request: Request<IngestFileRequest>,
139    ) -> Result<Response<IngestResponse>, Status> {
140        // Auth check (Write permission) - previously missing? or just implicit?
141        // Note: The original code didn't check auth for ingest_file!
142        // Adding it now for consistency as we are touching auth.
143        let token = get_token(&request);
144        let req = request.into_inner();
145        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
146
147        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
148             return Err(Status::permission_denied(e));
149        }
150        let store = self.get_store(namespace)?;
151
152        let engine = IngestionEngine::new(store);
153        let path = Path::new(&req.file_path);
154
155        match engine.ingest_file(path, namespace).await {
156            Ok(count) => Ok(Response::new(IngestResponse {
157                nodes_added: count,
158                edges_added: count,
159            })),
160            Err(e) => Err(Status::internal(e.to_string())),
161        }
162    }
163
164    async fn get_neighbors(
165        &self,
166        request: Request<NodeRequest>,
167    ) -> Result<Response<NeighborResponse>, Status> {
168        let req = request.into_inner();
169        let namespace = if req.namespace.is_empty() {
170            "default"
171        } else {
172            &req.namespace
173        };
174        let store = self.get_store(namespace)?;
175
176        let direction = if req.direction.is_empty() {
177            "outgoing"
178        } else {
179            &req.direction
180        };
181        let edge_filter = if req.edge_filter.is_empty() {
182            None
183        } else {
184            Some(req.edge_filter.as_str())
185        };
186        let max_depth = if req.depth == 0 {
187            1
188        } else {
189            req.depth as usize
190        };
191        let limit_per_layer = if req.limit_per_layer == 0 {
192            usize::MAX
193        } else {
194            req.limit_per_layer as usize
195        };
196
197        let mut neighbors = Vec::new();
198        let mut visited = std::collections::HashSet::new();
199        let mut current_frontier = Vec::new();
200
201        // Start with the initial node
202        if let Some(start_uri) = store.get_uri(req.node_id) {
203            current_frontier.push(start_uri.clone());
204            visited.insert(start_uri);
205        }
206
207        // BFS traversal up to max_depth
208        for current_depth in 1..=max_depth {
209            let mut next_frontier = Vec::new();
210            let mut layer_count = 0;
211            let score = 1.0 / current_depth as f32; // Path scoring: closer = higher
212
213            for uri in &current_frontier {
214                if layer_count >= limit_per_layer {
215                    break;
216                }
217
218                // Query outgoing edges (URI as subject)
219                if direction == "outgoing" || direction == "both" {
220                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
221                        for quad in
222                            store
223                                .store
224                                .quads_for_pattern(Some(subj.into()), None, None, None)
225                        {
226                            if layer_count >= limit_per_layer {
227                                break;
228                            }
229                            if let Ok(q) = quad {
230                                let pred = q.predicate.to_string();
231                                // Apply edge filter if specified
232                                if let Some(filter) = edge_filter {
233                                    if !pred.contains(filter) {
234                                        continue;
235                                    }
236                                }
237                                let obj_uri = q.object.to_string();
238                                if !visited.contains(&obj_uri) {
239                                    visited.insert(obj_uri.clone());
240                                    let obj_id = store.get_or_create_id(&obj_uri);
241                                    neighbors.push(Neighbor {
242                                        node_id: obj_id,
243                                        edge_type: pred,
244                                        uri: obj_uri.clone(),
245                                        direction: "outgoing".to_string(),
246                                        depth: current_depth as u32,
247                                        score,
248                                    });
249                                    next_frontier.push(obj_uri);
250                                    layer_count += 1;
251                                }
252                            }
253                        }
254                    }
255                }
256
257                // Query incoming edges (URI as object)
258                if direction == "incoming" || direction == "both" {
259                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
260                        for quad in
261                            store
262                                .store
263                                .quads_for_pattern(None, None, Some(obj.into()), None)
264                        {
265                            if layer_count >= limit_per_layer {
266                                break;
267                            }
268                            if let Ok(q) = quad {
269                                let pred = q.predicate.to_string();
270                                // Apply edge filter if specified
271                                if let Some(filter) = edge_filter {
272                                    if !pred.contains(filter) {
273                                        continue;
274                                    }
275                                }
276                                let subj_uri = q.subject.to_string();
277                                if !visited.contains(&subj_uri) {
278                                    visited.insert(subj_uri.clone());
279                                    let subj_id = store.get_or_create_id(&subj_uri);
280                                    neighbors.push(Neighbor {
281                                        node_id: subj_id,
282                                        edge_type: pred,
283                                        uri: subj_uri.clone(),
284                                        direction: "incoming".to_string(),
285                                        depth: current_depth as u32,
286                                        score,
287                                    });
288                                    next_frontier.push(subj_uri);
289                                    layer_count += 1;
290                                }
291                            }
292                        }
293                    }
294                }
295            }
296
297            current_frontier = next_frontier;
298            if current_frontier.is_empty() {
299                break;
300            }
301        }
302
303        // Sort by score (highest first)
304        neighbors.sort_by(|a, b| {
305            b.score
306                .partial_cmp(&a.score)
307                .unwrap_or(std::cmp::Ordering::Equal)
308        });
309
310        Ok(Response::new(NeighborResponse { neighbors }))
311    }
312
313    async fn search(
314        &self,
315        request: Request<SearchRequest>,
316    ) -> Result<Response<SearchResponse>, Status> {
317        let req = request.into_inner();
318        let namespace = if req.namespace.is_empty() {
319            "default"
320        } else {
321            &req.namespace
322        };
323        let store = self.get_store(namespace)?;
324
325        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
326            Ok(results) => {
327                let grpc_results = results
328                    .into_iter()
329                    .enumerate()
330                    .map(|(idx, (uri, score))| SearchResult {
331                        node_id: idx as u32,
332                        score,
333                        content: uri.clone(),
334                        uri,
335                    })
336                    .collect();
337                Ok(Response::new(SearchResponse {
338                    results: grpc_results,
339                }))
340            }
341            Err(e) => Err(Status::internal(e.to_string())),
342        }
343    }
344
345    async fn resolve_id(
346        &self,
347        request: Request<ResolveRequest>,
348    ) -> Result<Response<ResolveResponse>, Status> {
349        let req = request.into_inner();
350        let namespace = if req.namespace.is_empty() {
351            "default"
352        } else {
353            &req.namespace
354        };
355        let store = self.get_store(namespace)?;
356
357        // Look up the URI in our mapping
358        let uri_to_id = store.uri_to_id.read().unwrap();
359        if let Some(&node_id) = uri_to_id.get(&req.content) {
360            Ok(Response::new(ResolveResponse {
361                node_id,
362                found: true,
363            }))
364        } else {
365            Ok(Response::new(ResolveResponse {
366                node_id: 0,
367                found: false,
368            }))
369        }
370    }
371
372    async fn get_all_triples(
373        &self,
374        request: Request<EmptyRequest>,
375    ) -> Result<Response<TriplesResponse>, Status> {
376        let req = request.into_inner();
377        let namespace = if req.namespace.is_empty() {
378            "default"
379        } else {
380            &req.namespace
381        };
382        let store = self.get_store(namespace)?;
383
384        let mut triples = Vec::new();
385
386        for quad in store.store.iter().map(|q| q.unwrap()) {
387            triples.push(Triple {
388                subject: quad.subject.to_string(),
389                predicate: quad.predicate.to_string(),
390                object: quad.object.to_string(),
391                provenance: Some(Provenance {
392                    source: "oxigraph".to_string(),
393                    timestamp: "".to_string(),
394                    method: "storage".to_string(),
395                }),
396                embedding: vec![],
397            });
398        }
399
400        Ok(Response::new(TriplesResponse { triples }))
401    }
402
403    async fn query_sparql(
404        &self,
405        request: Request<SparqlRequest>,
406    ) -> Result<Response<SparqlResponse>, Status> {
407        let req = request.into_inner();
408        let namespace = if req.namespace.is_empty() {
409            "default"
410        } else {
411            &req.namespace
412        };
413        let store = self.get_store(namespace)?;
414
415        match store.query_sparql(&req.query) {
416            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
417            Err(e) => Err(Status::internal(e.to_string())),
418        }
419    }
420
421    async fn delete_namespace_data(
422        &self,
423        request: Request<EmptyRequest>,
424    ) -> Result<Response<DeleteResponse>, Status> {
425        let req = request.into_inner();
426        let namespace = if req.namespace.is_empty() {
427            "default"
428        } else {
429            &req.namespace
430        };
431
432        // Remove from cache
433        self.stores.remove(namespace);
434
435        // Delete directory
436        let path = Path::new(&self.storage_path).join(namespace);
437        if path.exists() {
438            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
439        }
440
441        Ok(Response::new(DeleteResponse {
442            success: true,
443            message: format!("Deleted namespace '{}'", namespace),
444        }))
445    }
446
447    async fn hybrid_search(
448        &self,
449        request: Request<HybridSearchRequest>,
450    ) -> Result<Response<SearchResponse>, Status> {
451        let req = request.into_inner();
452        let namespace = if req.namespace.is_empty() {
453            "default"
454        } else {
455            &req.namespace
456        };
457        let store = self.get_store(namespace)?;
458
459        let vector_k = req.vector_k as usize;
460        let graph_depth = req.graph_depth;
461
462        let results = match SearchMode::try_from(req.mode) {
463            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
464                .hybrid_search(&req.query, vector_k, graph_depth)
465                .await
466                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
467            _ => vec![],
468        };
469
470        let grpc_results = results
471            .into_iter()
472            .enumerate()
473            .map(|(idx, (uri, score))| SearchResult {
474                node_id: idx as u32,
475                score,
476                content: uri.clone(),
477                uri,
478            })
479            .collect();
480
481        Ok(Response::new(SearchResponse {
482            results: grpc_results,
483        }))
484    }
485
486    async fn apply_reasoning(
487        &self,
488        request: Request<ReasoningRequest>,
489    ) -> Result<Response<ReasoningResponse>, Status> {
490        // Auth check (Reason permission)
491        let token = get_token(&request);
492        let req = request.into_inner();
493        let namespace = if req.namespace.is_empty() {
494            "default"
495        } else {
496            &req.namespace
497        };
498
499        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
500            return Err(Status::permission_denied(e));
501        }
502
503        let store = self.get_store(namespace)?;
504
505        let strategy = match ReasoningStrategy::try_from(req.strategy) {
506            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
507            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
508            _ => InternalStrategy::None,
509        };
510        let strategy_name = format!("{:?}", strategy);
511
512        let reasoner = SynapseReasoner::new(strategy);
513        let start_triples = store.store.len().unwrap_or(0);
514
515        let response = if req.materialize {
516            match reasoner.materialize(&store.store) {
517                Ok(count) => Ok(Response::new(ReasoningResponse {
518                    success: true,
519                    triples_inferred: count as u32,
520                    message: format!(
521                        "Materialized {} triples in namespace '{}'",
522                        count, namespace
523                    ),
524                })),
525                Err(e) => Err(Status::internal(e.to_string())),
526            }
527        } else {
528            match reasoner.apply(&store.store) {
529                Ok(triples) => Ok(Response::new(ReasoningResponse {
530                    success: true,
531                    triples_inferred: triples.len() as u32,
532                    message: format!(
533                        "Found {} inferred triples in namespace '{}'",
534                        triples.len(),
535                        namespace
536                    ),
537                })),
538                Err(e) => Err(Status::internal(e.to_string())),
539            }
540        };
541
542        // Audit Log
543        if let Ok(ref res) = response {
544            let inferred = res.get_ref().triples_inferred as usize;
545            self.audit.log(
546                namespace,
547                &strategy_name,
548                start_triples,
549                inferred,
550                0, // Duplicates skipped not easily tracked here without changing reasoner return signature
551                vec![], // Sample inferences
552            );
553        }
554
555        response
556    }
557}
558
559pub async fn run_mcp_stdio(
560    engine: Arc<MySemanticEngine>,
561) -> Result<(), Box<dyn std::error::Error>> {
562    let server = crate::mcp_stdio::McpStdioServer::new(engine);
563    server.run().await
564}