Skip to main content

synapse_core/
server.rs

1use dashmap::DashMap;
2use std::sync::Arc;
3use tonic::{Request, Response, Status};
4
5pub mod proto {
6    tonic::include_proto!("semantic_engine");
7}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::ingest::IngestionEngine;
13use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::store::{SynapseStore, IngestTriple};
16use std::path::Path;
17
18use crate::audit::InferenceAudit;
19use crate::auth::NamespaceAuth;
20
/// Bearer token attached to a request's extensions.
///
/// [`auth_interceptor`] stores one of these when the request carries an
/// `authorization` header; `get_token` reads it back.
#[derive(Clone)]
pub struct AuthToken(pub String);
23
24pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
25    if let Some(token) = req.metadata().get("authorization")
26        .and_then(|t| t.to_str().ok())
27        .map(|s| s.trim_start_matches("Bearer ").to_string()) {
28            req.extensions_mut().insert(AuthToken(token));
29    }
30    Ok(req)
31}
32
33fn get_token<T>(req: &Request<T>) -> Option<String> {
34    if let Some(token) = req.extensions().get::<AuthToken>() {
35        return Some(token.0.clone());
36    }
37    req.metadata().get("authorization")
38       .and_then(|t| t.to_str().ok())
39       .map(|s| s.trim_start_matches("Bearer ").to_string())
40}
41
42pub struct MySemanticEngine {
43    pub storage_path: String,
44    pub stores: DashMap<String, Arc<SynapseStore>>,
45    pub auth: Arc<NamespaceAuth>,
46    pub audit: Arc<InferenceAudit>,
47}
48
49impl MySemanticEngine {
50    pub fn new(storage_path: &str) -> Self {
51        let auth = Arc::new(NamespaceAuth::new());
52        auth.load_from_env();
53
54        Self {
55            storage_path: storage_path.to_string(),
56            stores: DashMap::new(),
57            auth,
58            audit: Arc::new(InferenceAudit::new()),
59        }
60    }
61
62    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
63        if let Some(store) = self.stores.get(namespace) {
64            return Ok(store.clone());
65        }
66
67        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
68            Status::internal(format!(
69                "Failed to open store for namespace '{}': {}",
70                namespace, e
71            ))
72        })?;
73
74        let store_arc = Arc::new(store);
75        self.stores.insert(namespace.to_string(), store_arc.clone());
76        Ok(store_arc)
77    }
78}
79
80#[tonic::async_trait]
81impl SemanticEngine for MySemanticEngine {
82    async fn ingest_triples(
83        &self,
84        request: Request<IngestRequest>,
85    ) -> Result<Response<IngestResponse>, Status> {
86        // Auth check (Write permission)
87        let token = get_token(&request);
88        let req = request.into_inner();
89        let namespace = if req.namespace.is_empty() {
90            "default"
91        } else {
92            &req.namespace
93        };
94
95        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
96            return Err(Status::permission_denied(e));
97        }
98
99        let store = self.get_store(namespace)?;
100
101        // Log provenance for audit
102        let timestamp = chrono::Utc::now().to_rfc3339();
103        let triple_count = req.triples.len();
104        let mut sources: Vec<String> = Vec::new();
105
106        let triples: Vec<IngestTriple> = req
107            .triples
108            .into_iter()
109            .map(|t| {
110                // Capture provenance sources for logging
111                if let Some(ref prov) = t.provenance {
112                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
113                        sources.push(prov.source.clone());
114                    }
115                }
116                IngestTriple {
117                    subject: t.subject,
118                    predicate: t.predicate,
119                    object: t.object,
120                    provenance: t.provenance.map(|p| crate::store::Provenance {
121                        source: p.source,
122                        timestamp: p.timestamp,
123                        method: p.method,
124                    }),
125                }
126            })
127            .collect();
128
129        match store.ingest_triples(triples).await {
130            Ok((added, _)) => {
131                // Log ingestion for audit trail
132                eprintln!(
133                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
134                    sources
135                );
136                Ok(Response::new(IngestResponse {
137                    nodes_added: added,
138                    edges_added: added,
139                }))
140            }
141            Err(e) => Err(Status::internal(e.to_string())),
142        }
143    }
144
145    async fn ingest_file(
146        &self,
147        request: Request<IngestFileRequest>,
148    ) -> Result<Response<IngestResponse>, Status> {
149        // Auth check (Write permission) - previously missing? or just implicit?
150        // Note: The original code didn't check auth for ingest_file!
151        // Adding it now for consistency as we are touching auth.
152        let token = get_token(&request);
153        let req = request.into_inner();
154        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };
155
156        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
157             return Err(Status::permission_denied(e));
158        }
159        let store = self.get_store(namespace)?;
160
161        let engine = IngestionEngine::new(store);
162        let path = Path::new(&req.file_path);
163
164        match engine.ingest_file(path, namespace).await {
165            Ok(count) => Ok(Response::new(IngestResponse {
166                nodes_added: count,
167                edges_added: count,
168            })),
169            Err(e) => Err(Status::internal(e.to_string())),
170        }
171    }
172
173    async fn get_neighbors(
174        &self,
175        request: Request<NodeRequest>,
176    ) -> Result<Response<NeighborResponse>, Status> {
177        let token = get_token(&request);
178        let req = request.into_inner();
179        let namespace = if req.namespace.is_empty() {
180            "default"
181        } else {
182            &req.namespace
183        };
184
185        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
186            return Err(Status::permission_denied(e));
187        }
188
189        let store = self.get_store(namespace)?;
190
191        let direction = if req.direction.is_empty() {
192            "outgoing"
193        } else {
194            &req.direction
195        };
196        let edge_filter = if req.edge_filter.is_empty() {
197            None
198        } else {
199            Some(req.edge_filter.as_str())
200        };
201        let node_type_filter = if req.node_type_filter.is_empty() {
202            None
203        } else {
204            Some(req.node_type_filter.as_str())
205        };
206        let max_depth = if req.depth == 0 {
207            1
208        } else {
209            req.depth as usize
210        };
211        let limit_per_layer = if req.limit_per_layer == 0 {
212            usize::MAX
213        } else {
214            req.limit_per_layer as usize
215        };
216
217        let mut neighbors = Vec::new();
218        let mut visited = std::collections::HashSet::new();
219        let mut current_frontier = Vec::new();
220
221        // Start with the initial node
222        if let Some(start_uri) = store.get_uri(req.node_id) {
223            current_frontier.push(start_uri.clone());
224            visited.insert(start_uri);
225        }
226
227        // BFS traversal up to max_depth
228        for current_depth in 1..=max_depth {
229            let mut next_frontier = Vec::new();
230            let mut layer_count = 0;
231            let base_score = 1.0 / current_depth as f32; // Path scoring: closer = higher
232
233            for uri in &current_frontier {
234                if layer_count >= limit_per_layer {
235                    break;
236                }
237
238                // Query outgoing edges (URI as subject)
239                if direction == "outgoing" || direction == "both" {
240                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
241                        for quad in
242                            store
243                                .store
244                                .quads_for_pattern(Some(subj.into()), None, None, None)
245                        {
246                            if layer_count >= limit_per_layer {
247                                break;
248                            }
249                            if let Ok(q) = quad {
250                                let pred = q.predicate.to_string();
251                                // Apply edge filter if specified
252                                if let Some(filter) = edge_filter {
253                                    if !pred.contains(filter) {
254                                        continue;
255                                    }
256                                }
257                                let obj_term = q.object;
258                                let obj_uri = obj_term.to_string();
259
260                                // Node Type Filter Logic
261                                if let Some(type_filter) = node_type_filter {
262                                    let passed = if let oxigraph::model::Term::NamedNode(ref n) = obj_term {
263                                        let rdf_type = oxigraph::model::NamedNodeRef::new("http://www.w3.org/1999/02/22-rdf-syntax-ns#type").unwrap();
264                                        if let Ok(target_type) = oxigraph::model::NamedNodeRef::new(type_filter) {
265                                            store.store.quads_for_pattern(
266                                                Some(n.into()),
267                                                Some(rdf_type.into()),
268                                                Some(target_type.into()),
269                                                None
270                                            ).next().is_some()
271                                        } else {
272                                            false
273                                        }
274                                    } else {
275                                        false
276                                    };
277                                    if !passed { continue; }
278                                }
279
280                                let clean_uri = match &obj_term {
281                                    oxigraph::model::Term::NamedNode(n) => n.as_str(),
282                                    _ => &obj_uri,
283                                };
284
285                                if !visited.contains(&obj_uri) {
286                                    visited.insert(obj_uri.clone());
287                                    let obj_id = store.get_or_create_id(&obj_uri);
288
289                                    let mut neighbor_score = base_score;
290                                    if req.scoring_strategy == "degree" {
291                                        let degree = store.get_degree(clean_uri);
292                                        // Penalize if degree > 1
293                                        if degree > 1 {
294                                             neighbor_score /= (degree as f32).ln().max(1.0);
295                                        }
296                                    }
297
298                                    neighbors.push(Neighbor {
299                                        node_id: obj_id,
300                                        edge_type: pred,
301                                        uri: obj_uri.clone(),
302                                        direction: "outgoing".to_string(),
303                                        depth: current_depth as u32,
304                                        score: neighbor_score,
305                                    });
306                                    next_frontier.push(obj_uri);
307                                    layer_count += 1;
308                                }
309                            }
310                        }
311                    }
312                }
313
314                // Query incoming edges (URI as object)
315                if direction == "incoming" || direction == "both" {
316                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
317                        for quad in
318                            store
319                                .store
320                                .quads_for_pattern(None, None, Some(obj.into()), None)
321                        {
322                            if layer_count >= limit_per_layer {
323                                break;
324                            }
325                            if let Ok(q) = quad {
326                                let pred = q.predicate.to_string();
327                                // Apply edge filter if specified
328                                if let Some(filter) = edge_filter {
329                                    if !pred.contains(filter) {
330                                        continue;
331                                    }
332                                }
333                                let subj_term = q.subject;
334                                let subj_uri = subj_term.to_string();
335
336                                // Node Type Filter Logic
337                                if let Some(type_filter) = node_type_filter {
338                                    let passed = if let oxigraph::model::Subject::NamedNode(ref n) = subj_term {
339                                        let rdf_type = oxigraph::model::NamedNodeRef::new("http://www.w3.org/1999/02/22-rdf-syntax-ns#type").unwrap();
340                                        if let Ok(target_type) = oxigraph::model::NamedNodeRef::new(type_filter) {
341                                            store.store.quads_for_pattern(
342                                                Some(n.into()),
343                                                Some(rdf_type.into()),
344                                                Some(target_type.into()),
345                                                None
346                                            ).next().is_some()
347                                        } else {
348                                            false
349                                        }
350                                    } else {
351                                        false
352                                    };
353                                    if !passed { continue; }
354                                }
355
356                                let clean_uri = match &subj_term {
357                                    oxigraph::model::Subject::NamedNode(n) => n.as_str(),
358                                    _ => &subj_uri,
359                                };
360
361                                if !visited.contains(&subj_uri) {
362                                    visited.insert(subj_uri.clone());
363                                    let subj_id = store.get_or_create_id(&subj_uri);
364
365                                    let mut neighbor_score = base_score;
366                                    if req.scoring_strategy == "degree" {
367                                        let degree = store.get_degree(clean_uri);
368                                        // Penalize if degree > 1
369                                        if degree > 1 {
370                                             neighbor_score /= (degree as f32).ln().max(1.0);
371                                        }
372                                    }
373
374                                    neighbors.push(Neighbor {
375                                        node_id: subj_id,
376                                        edge_type: pred,
377                                        uri: subj_uri.clone(),
378                                        direction: "incoming".to_string(),
379                                        depth: current_depth as u32,
380                                        score: neighbor_score,
381                                    });
382                                    next_frontier.push(subj_uri);
383                                    layer_count += 1;
384                                }
385                            }
386                        }
387                    }
388                }
389            }
390
391            current_frontier = next_frontier;
392            if current_frontier.is_empty() {
393                break;
394            }
395        }
396
397        // Sort by score (highest first)
398        neighbors.sort_by(|a, b| {
399            b.score
400                .partial_cmp(&a.score)
401                .unwrap_or(std::cmp::Ordering::Equal)
402        });
403
404        Ok(Response::new(NeighborResponse { neighbors }))
405    }
406
407    async fn search(
408        &self,
409        request: Request<SearchRequest>,
410    ) -> Result<Response<SearchResponse>, Status> {
411        let token = get_token(&request);
412        let req = request.into_inner();
413        let namespace = if req.namespace.is_empty() {
414            "default"
415        } else {
416            &req.namespace
417        };
418
419        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
420            return Err(Status::permission_denied(e));
421        }
422
423        let store = self.get_store(namespace)?;
424
425        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
426            Ok(results) => {
427                let grpc_results = results
428                    .into_iter()
429                    .enumerate()
430                    .map(|(idx, (uri, score))| SearchResult {
431                        node_id: idx as u32,
432                        score,
433                        content: uri.clone(),
434                        uri,
435                    })
436                    .collect();
437                Ok(Response::new(SearchResponse {
438                    results: grpc_results,
439                }))
440            }
441            Err(e) => Err(Status::internal(e.to_string())),
442        }
443    }
444
445    async fn resolve_id(
446        &self,
447        request: Request<ResolveRequest>,
448    ) -> Result<Response<ResolveResponse>, Status> {
449        let token = get_token(&request);
450        let req = request.into_inner();
451        let namespace = if req.namespace.is_empty() {
452            "default"
453        } else {
454            &req.namespace
455        };
456
457        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
458            return Err(Status::permission_denied(e));
459        }
460
461        let store = self.get_store(namespace)?;
462
463        let uri = store.ensure_uri(&req.content);
464
465        // Look up the URI in our mapping
466        let uri_to_id = store.uri_to_id.read().unwrap();
467        if let Some(&node_id) = uri_to_id.get(&uri) {
468            Ok(Response::new(ResolveResponse {
469                node_id,
470                found: true,
471            }))
472        } else {
473            Ok(Response::new(ResolveResponse {
474                node_id: 0,
475                found: false,
476            }))
477        }
478    }
479
480    async fn get_all_triples(
481        &self,
482        request: Request<EmptyRequest>,
483    ) -> Result<Response<TriplesResponse>, Status> {
484        let token = get_token(&request);
485        let req = request.into_inner();
486        let namespace = if req.namespace.is_empty() {
487            "default"
488        } else {
489            &req.namespace
490        };
491
492        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
493            return Err(Status::permission_denied(e));
494        }
495
496        let store = self.get_store(namespace)?;
497
498        let mut triples = Vec::new();
499
500        for quad in store.store.iter().map(|q| q.unwrap()) {
501            let s = quad.subject.to_string();
502            let p = quad.predicate.to_string();
503            let o = quad.object.to_string();
504
505            // Clean up NTriples formatting (<uri> -> uri)
506            let clean_s = if s.starts_with('<') && s.ends_with('>') { s[1..s.len()-1].to_string() } else { s };
507            let clean_p = if p.starts_with('<') && p.ends_with('>') { p[1..p.len()-1].to_string() } else { p };
508            let clean_o = if o.starts_with('<') && o.ends_with('>') { o[1..o.len()-1].to_string() } else { o };
509
510            triples.push(Triple {
511                subject: clean_s,
512                predicate: clean_p,
513                object: clean_o,
514                provenance: Some(Provenance {
515                    source: "oxigraph".to_string(),
516                    timestamp: "".to_string(),
517                    method: "storage".to_string(),
518                }),
519                embedding: vec![],
520            });
521        }
522
523        Ok(Response::new(TriplesResponse { triples }))
524    }
525
526    async fn query_sparql(
527        &self,
528        request: Request<SparqlRequest>,
529    ) -> Result<Response<SparqlResponse>, Status> {
530        let token = get_token(&request);
531        let req = request.into_inner();
532        let namespace = if req.namespace.is_empty() {
533            "default"
534        } else {
535            &req.namespace
536        };
537
538        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
539            return Err(Status::permission_denied(e));
540        }
541
542        let store = self.get_store(namespace)?;
543
544        match store.query_sparql(&req.query) {
545            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
546            Err(e) => Err(Status::internal(e.to_string())),
547        }
548    }
549
550    async fn delete_namespace_data(
551        &self,
552        request: Request<EmptyRequest>,
553    ) -> Result<Response<DeleteResponse>, Status> {
554        let token = get_token(&request);
555        let req = request.into_inner();
556        let namespace = if req.namespace.is_empty() {
557            "default"
558        } else {
559            &req.namespace
560        };
561
562        if let Err(e) = self.auth.check(token.as_deref(), namespace, "delete") {
563            return Err(Status::permission_denied(e));
564        }
565
566        // Remove from cache
567        self.stores.remove(namespace);
568
569        // Delete directory
570        let path = Path::new(&self.storage_path).join(namespace);
571        if path.exists() {
572            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
573        }
574
575        Ok(Response::new(DeleteResponse {
576            success: true,
577            message: format!("Deleted namespace '{}'", namespace),
578        }))
579    }
580
581    async fn hybrid_search(
582        &self,
583        request: Request<HybridSearchRequest>,
584    ) -> Result<Response<SearchResponse>, Status> {
585        let token = get_token(&request);
586        let req = request.into_inner();
587        let namespace = if req.namespace.is_empty() {
588            "default"
589        } else {
590            &req.namespace
591        };
592
593        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
594            return Err(Status::permission_denied(e));
595        }
596
597        let store = self.get_store(namespace)?;
598
599        let vector_k = req.vector_k as usize;
600        let graph_depth = req.graph_depth;
601
602        let results = match SearchMode::try_from(req.mode) {
603            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
604                .hybrid_search(&req.query, vector_k, graph_depth)
605                .await
606                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
607            _ => vec![],
608        };
609
610        let grpc_results = results
611            .into_iter()
612            .enumerate()
613            .map(|(idx, (uri, score))| SearchResult {
614                node_id: idx as u32,
615                score,
616                content: uri.clone(),
617                uri,
618            })
619            .collect();
620
621        Ok(Response::new(SearchResponse {
622            results: grpc_results,
623        }))
624    }
625
626    async fn apply_reasoning(
627        &self,
628        request: Request<ReasoningRequest>,
629    ) -> Result<Response<ReasoningResponse>, Status> {
630        // Auth check (Reason permission)
631        let token = get_token(&request);
632        let req = request.into_inner();
633        let namespace = if req.namespace.is_empty() {
634            "default"
635        } else {
636            &req.namespace
637        };
638
639        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
640            return Err(Status::permission_denied(e));
641        }
642
643        let store = self.get_store(namespace)?;
644
645        let strategy = match ReasoningStrategy::try_from(req.strategy) {
646            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
647            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
648            _ => InternalStrategy::None,
649        };
650        let strategy_name = format!("{:?}", strategy);
651
652        let reasoner = SynapseReasoner::new(strategy);
653        let start_triples = store.store.len().unwrap_or(0);
654
655        let response = if req.materialize {
656            match reasoner.materialize(&store.store) {
657                Ok(count) => Ok(Response::new(ReasoningResponse {
658                    success: true,
659                    triples_inferred: count as u32,
660                    message: format!(
661                        "Materialized {} triples in namespace '{}'",
662                        count, namespace
663                    ),
664                })),
665                Err(e) => Err(Status::internal(e.to_string())),
666            }
667        } else {
668            match reasoner.apply(&store.store) {
669                Ok(triples) => Ok(Response::new(ReasoningResponse {
670                    success: true,
671                    triples_inferred: triples.len() as u32,
672                    message: format!(
673                        "Found {} inferred triples in namespace '{}'",
674                        triples.len(),
675                        namespace
676                    ),
677                })),
678                Err(e) => Err(Status::internal(e.to_string())),
679            }
680        };
681
682        // Audit Log
683        if let Ok(ref res) = response {
684            let inferred = res.get_ref().triples_inferred as usize;
685            self.audit.log(
686                namespace,
687                &strategy_name,
688                start_triples,
689                inferred,
690                0, // Duplicates skipped not easily tracked here without changing reasoner return signature
691                vec![], // Sample inferences
692            );
693        }
694
695        response
696    }
697}
698
699pub async fn run_mcp_stdio(
700    engine: Arc<MySemanticEngine>,
701) -> Result<(), Box<dyn std::error::Error>> {
702    let server = crate::mcp_stdio::McpStdioServer::new(engine);
703    server.run().await
704}