Skip to main content

synapse_core/
server.rs

1use dashmap::DashMap;
2use std::sync::Arc;
3use tonic::{Request, Response, Status};
4
5pub mod proto {
6    tonic::include_proto!("semantic_engine");
7}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::ingest::IngestionEngine;
13use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::store::{IngestTriple, SynapseStore};
16use std::path::Path;
17
18use crate::audit::InferenceAudit;
19use crate::auth::NamespaceAuth;
20
/// Bearer credential taken from a request's `authorization` metadata.
///
/// `auth_interceptor` stores a value of this type in the request extensions,
/// and `get_token` reads it back from there.
#[derive(Clone)]
pub struct AuthToken(pub String);
23
24#[allow(clippy::result_large_err)]
25pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
26    if let Some(token) = req
27        .metadata()
28        .get("authorization")
29        .and_then(|t| t.to_str().ok())
30        .map(|s| s.trim_start_matches("Bearer ").to_string())
31    {
32        req.extensions_mut().insert(AuthToken(token));
33    }
34    Ok(req)
35}
36
37fn get_token<T>(req: &Request<T>) -> Option<String> {
38    if let Some(token) = req.extensions().get::<AuthToken>() {
39        return Some(token.0.clone());
40    }
41    req.metadata()
42        .get("authorization")
43        .and_then(|t| t.to_str().ok())
44        .map(|s| s.trim_start_matches("Bearer ").to_string())
45}
46
/// gRPC service state shared by all request handlers.
///
/// Cloning is cheap: every field except `storage_path` is behind an `Arc`.
#[derive(Clone)]
pub struct MySemanticEngine {
    /// Root path handed to `SynapseStore::open` and joined with the namespace
    /// when a namespace directory is deleted.
    pub storage_path: String,
    /// Opened stores, keyed by namespace name.
    pub stores: Arc<DashMap<String, Arc<SynapseStore>>>,
    /// Permission checker consulted by every handler before touching a store.
    pub auth: Arc<NamespaceAuth>,
    /// Audit sink that `apply_reasoning` logs successful runs to.
    pub audit: Arc<InferenceAudit>,
}
54
55impl MySemanticEngine {
56    pub fn new(storage_path: &str) -> Self {
57        let auth = Arc::new(NamespaceAuth::new());
58        auth.load_from_env();
59
60        Self {
61            storage_path: storage_path.to_string(),
62            stores: Arc::new(DashMap::new()),
63            auth,
64            audit: Arc::new(InferenceAudit::new()),
65        }
66    }
67
68    pub async fn shutdown(&self) {
69        eprintln!("Shutting down... flushing {} stores", self.stores.len());
70        for entry in self.stores.iter() {
71            let store = entry.value();
72            if let Err(e) = store.flush() {
73                eprintln!("Failed to flush store '{}': {}", entry.key(), e);
74            }
75        }
76        eprintln!("Shutdown complete.");
77    }
78
79    #[allow(clippy::result_large_err)]
80    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
81        if let Some(store) = self.stores.get(namespace) {
82            return Ok(store.clone());
83        }
84
85        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
86            Status::internal(format!(
87                "Failed to open store for namespace '{}': {}",
88                namespace, e
89            ))
90        })?;
91
92        let store_arc = Arc::new(store);
93        self.stores.insert(namespace.to_string(), store_arc.clone());
94        Ok(store_arc)
95    }
96}
97
98#[tonic::async_trait]
99impl SemanticEngine for MySemanticEngine {
100    async fn ingest_triples(
101        &self,
102        request: Request<IngestRequest>,
103    ) -> Result<Response<IngestResponse>, Status> {
104        // Auth check (Write permission)
105        let token = get_token(&request);
106        let req = request.into_inner();
107        let namespace = if req.namespace.is_empty() {
108            "default"
109        } else {
110            &req.namespace
111        };
112
113        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
114            return Err(Status::permission_denied(e));
115        }
116
117        let store = self.get_store(namespace)?;
118
119        // Log provenance for audit
120        let timestamp = chrono::Utc::now().to_rfc3339();
121        let triple_count = req.triples.len();
122        let mut sources: Vec<String> = Vec::new();
123
124        let triples: Vec<IngestTriple> = req
125            .triples
126            .into_iter()
127            .map(|t| {
128                // Capture provenance sources for logging
129                if let Some(ref prov) = t.provenance {
130                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
131                        sources.push(prov.source.clone());
132                    }
133                }
134                IngestTriple {
135                    subject: t.subject,
136                    predicate: t.predicate,
137                    object: t.object,
138                    provenance: t.provenance.map(|p| crate::store::Provenance {
139                        source: p.source,
140                        timestamp: p.timestamp,
141                        method: p.method,
142                    }),
143                }
144            })
145            .collect();
146
147        match store.ingest_triples(triples).await {
148            Ok((added, _)) => {
149                // Log ingestion for audit trail
150                eprintln!(
151                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
152                    sources
153                );
154                Ok(Response::new(IngestResponse {
155                    nodes_added: added,
156                    edges_added: added,
157                }))
158            }
159            Err(e) => Err(Status::internal(e.to_string())),
160        }
161    }
162
163    async fn ingest_file(
164        &self,
165        request: Request<IngestFileRequest>,
166    ) -> Result<Response<IngestResponse>, Status> {
167        // Auth check (Write permission) - previously missing? or just implicit?
168        // Note: The original code didn't check auth for ingest_file!
169        // Adding it now for consistency as we are touching auth.
170        let token = get_token(&request);
171        let req = request.into_inner();
172        let namespace = if req.namespace.is_empty() {
173            "default"
174        } else {
175            &req.namespace
176        };
177
178        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
179            return Err(Status::permission_denied(e));
180        }
181        let store = self.get_store(namespace)?;
182
183        let engine = IngestionEngine::new(store);
184        let path = Path::new(&req.file_path);
185
186        match engine.ingest_file(path, namespace).await {
187            Ok(count) => Ok(Response::new(IngestResponse {
188                nodes_added: count,
189                edges_added: count,
190            })),
191            Err(e) => Err(Status::internal(e.to_string())),
192        }
193    }
194
195    async fn get_neighbors(
196        &self,
197        request: Request<NodeRequest>,
198    ) -> Result<Response<NeighborResponse>, Status> {
199        let token = get_token(&request);
200        let req = request.into_inner();
201        let namespace = if req.namespace.is_empty() {
202            "default"
203        } else {
204            &req.namespace
205        };
206
207        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
208            return Err(Status::permission_denied(e));
209        }
210
211        let store = self.get_store(namespace)?;
212
213        let direction = if req.direction.is_empty() {
214            "outgoing"
215        } else {
216            &req.direction
217        };
218        let edge_filter = if req.edge_filter.is_empty() {
219            None
220        } else {
221            Some(req.edge_filter.as_str())
222        };
223        let node_type_filter = if req.node_type_filter.is_empty() {
224            None
225        } else {
226            Some(req.node_type_filter.as_str())
227        };
228        let max_depth = if req.depth == 0 {
229            1
230        } else {
231            req.depth as usize
232        };
233        let limit_per_layer = if req.limit_per_layer == 0 {
234            usize::MAX
235        } else {
236            req.limit_per_layer as usize
237        };
238
239        let mut neighbors = Vec::new();
240        let mut visited = std::collections::HashSet::new();
241        let mut current_frontier = Vec::new();
242
243        // Start with the initial node
244        if let Some(start_uri) = store.get_uri(req.node_id) {
245            current_frontier.push(start_uri.clone());
246            visited.insert(start_uri);
247        }
248
249        // BFS traversal up to max_depth
250        for current_depth in 1..=max_depth {
251            let mut next_frontier = Vec::new();
252            let mut layer_count = 0;
253            let base_score = 1.0 / current_depth as f32; // Path scoring: closer = higher
254
255            for uri in &current_frontier {
256                if layer_count >= limit_per_layer {
257                    break;
258                }
259
260                // Query outgoing edges (URI as subject)
261                if direction == "outgoing" || direction == "both" {
262                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
263                        for quad in
264                            store
265                                .store
266                                .quads_for_pattern(Some(subj.into()), None, None, None)
267                        {
268                            if layer_count >= limit_per_layer {
269                                break;
270                            }
271                            if let Ok(q) = quad {
272                                let pred = q.predicate.to_string();
273                                // Apply edge filter if specified
274                                if let Some(filter) = edge_filter {
275                                    if !pred.contains(filter) {
276                                        continue;
277                                    }
278                                }
279                                let obj_term = q.object;
280                                let obj_uri = obj_term.to_string();
281
282                                // Node Type Filter Logic
283                                if let Some(type_filter) = node_type_filter {
284                                    let passed =
285                                        if let oxigraph::model::Term::NamedNode(ref n) = obj_term {
286                                            let rdf_type = oxigraph::model::NamedNodeRef::new(
287                                                "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
288                                            )
289                                            .unwrap();
290                                            if let Ok(target_type) =
291                                                oxigraph::model::NamedNodeRef::new(type_filter)
292                                            {
293                                                store
294                                                    .store
295                                                    .quads_for_pattern(
296                                                        Some(n.into()),
297                                                        Some(rdf_type),
298                                                        Some(target_type.into()),
299                                                        None,
300                                                    )
301                                                    .next()
302                                                    .is_some()
303                                            } else {
304                                                false
305                                            }
306                                        } else {
307                                            false
308                                        };
309                                    if !passed {
310                                        continue;
311                                    }
312                                }
313
314                                let clean_uri = match &obj_term {
315                                    oxigraph::model::Term::NamedNode(n) => n.as_str(),
316                                    _ => &obj_uri,
317                                };
318
319                                if !visited.contains(&obj_uri) {
320                                    visited.insert(obj_uri.clone());
321                                    let obj_id = store.get_or_create_id(&obj_uri);
322
323                                    let mut neighbor_score = base_score;
324                                    if req.scoring_strategy == "degree" {
325                                        let degree = store.get_degree(clean_uri);
326                                        // Penalize if degree > 1
327                                        if degree > 1 {
328                                            neighbor_score /= (degree as f32).ln().max(1.0);
329                                        }
330                                    }
331
332                                    neighbors.push(Neighbor {
333                                        node_id: obj_id,
334                                        edge_type: pred,
335                                        uri: obj_uri.clone(),
336                                        direction: "outgoing".to_string(),
337                                        depth: current_depth as u32,
338                                        score: neighbor_score,
339                                    });
340                                    next_frontier.push(obj_uri);
341                                    layer_count += 1;
342                                }
343                            }
344                        }
345                    }
346                }
347
348                // Query incoming edges (URI as object)
349                if direction == "incoming" || direction == "both" {
350                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
351                        for quad in
352                            store
353                                .store
354                                .quads_for_pattern(None, None, Some(obj.into()), None)
355                        {
356                            if layer_count >= limit_per_layer {
357                                break;
358                            }
359                            if let Ok(q) = quad {
360                                let pred = q.predicate.to_string();
361                                // Apply edge filter if specified
362                                if let Some(filter) = edge_filter {
363                                    if !pred.contains(filter) {
364                                        continue;
365                                    }
366                                }
367                                let subj_term = q.subject;
368                                let subj_uri = subj_term.to_string();
369
370                                // Node Type Filter Logic
371                                if let Some(type_filter) = node_type_filter {
372                                    let passed = if let oxigraph::model::Subject::NamedNode(ref n) =
373                                        subj_term
374                                    {
375                                        let rdf_type = oxigraph::model::NamedNodeRef::new(
376                                            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
377                                        )
378                                        .unwrap();
379                                        if let Ok(target_type) =
380                                            oxigraph::model::NamedNodeRef::new(type_filter)
381                                        {
382                                            store
383                                                .store
384                                                .quads_for_pattern(
385                                                    Some(n.into()),
386                                                    Some(rdf_type),
387                                                    Some(target_type.into()),
388                                                    None,
389                                                )
390                                                .next()
391                                                .is_some()
392                                        } else {
393                                            false
394                                        }
395                                    } else {
396                                        false
397                                    };
398                                    if !passed {
399                                        continue;
400                                    }
401                                }
402
403                                let clean_uri = match &subj_term {
404                                    oxigraph::model::Subject::NamedNode(n) => n.as_str(),
405                                    _ => &subj_uri,
406                                };
407
408                                if !visited.contains(&subj_uri) {
409                                    visited.insert(subj_uri.clone());
410                                    let subj_id = store.get_or_create_id(&subj_uri);
411
412                                    let mut neighbor_score = base_score;
413                                    if req.scoring_strategy == "degree" {
414                                        let degree = store.get_degree(clean_uri);
415                                        // Penalize if degree > 1
416                                        if degree > 1 {
417                                            neighbor_score /= (degree as f32).ln().max(1.0);
418                                        }
419                                    }
420
421                                    neighbors.push(Neighbor {
422                                        node_id: subj_id,
423                                        edge_type: pred,
424                                        uri: subj_uri.clone(),
425                                        direction: "incoming".to_string(),
426                                        depth: current_depth as u32,
427                                        score: neighbor_score,
428                                    });
429                                    next_frontier.push(subj_uri);
430                                    layer_count += 1;
431                                }
432                            }
433                        }
434                    }
435                }
436            }
437
438            current_frontier = next_frontier;
439            if current_frontier.is_empty() {
440                break;
441            }
442        }
443
444        // Sort by score (highest first)
445        neighbors.sort_by(|a, b| {
446            b.score
447                .partial_cmp(&a.score)
448                .unwrap_or(std::cmp::Ordering::Equal)
449        });
450
451        Ok(Response::new(NeighborResponse { neighbors }))
452    }
453
454    async fn search(
455        &self,
456        request: Request<SearchRequest>,
457    ) -> Result<Response<SearchResponse>, Status> {
458        let token = get_token(&request);
459        let req = request.into_inner();
460        let namespace = if req.namespace.is_empty() {
461            "default"
462        } else {
463            &req.namespace
464        };
465
466        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
467            return Err(Status::permission_denied(e));
468        }
469
470        let store = self.get_store(namespace)?;
471
472        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
473            Ok(results) => {
474                let grpc_results = results
475                    .into_iter()
476                    .enumerate()
477                    .map(|(idx, (uri, score))| SearchResult {
478                        node_id: idx as u32,
479                        score,
480                        content: uri.clone(),
481                        uri,
482                    })
483                    .collect();
484                Ok(Response::new(SearchResponse {
485                    results: grpc_results,
486                }))
487            }
488            Err(e) => Err(Status::internal(e.to_string())),
489        }
490    }
491
492    async fn resolve_id(
493        &self,
494        request: Request<ResolveRequest>,
495    ) -> Result<Response<ResolveResponse>, Status> {
496        let token = get_token(&request);
497        let req = request.into_inner();
498        let namespace = if req.namespace.is_empty() {
499            "default"
500        } else {
501            &req.namespace
502        };
503
504        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
505            return Err(Status::permission_denied(e));
506        }
507
508        let store = self.get_store(namespace)?;
509
510        let uri = store.ensure_uri(&req.content);
511
512        // Look up the URI in our mapping
513        let uri_to_id = store.uri_to_id.read().unwrap();
514        if let Some(&node_id) = uri_to_id.get(&uri) {
515            Ok(Response::new(ResolveResponse {
516                node_id,
517                found: true,
518            }))
519        } else {
520            Ok(Response::new(ResolveResponse {
521                node_id: 0,
522                found: false,
523            }))
524        }
525    }
526
527    async fn get_all_triples(
528        &self,
529        request: Request<EmptyRequest>,
530    ) -> Result<Response<TriplesResponse>, Status> {
531        let token = get_token(&request);
532        let req = request.into_inner();
533        let namespace = if req.namespace.is_empty() {
534            "default"
535        } else {
536            &req.namespace
537        };
538
539        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
540            return Err(Status::permission_denied(e));
541        }
542
543        let store = self.get_store(namespace)?;
544
545        let mut triples = Vec::new();
546
547        for quad in store.store.iter().map(|q| q.unwrap()) {
548            let s = quad.subject.to_string();
549            let p = quad.predicate.to_string();
550            let o = quad.object.to_string();
551
552            // Clean up NTriples formatting (<uri> -> uri)
553            let clean_s = if s.starts_with('<') && s.ends_with('>') {
554                s[1..s.len() - 1].to_string()
555            } else {
556                s
557            };
558            let clean_p = if p.starts_with('<') && p.ends_with('>') {
559                p[1..p.len() - 1].to_string()
560            } else {
561                p
562            };
563            let clean_o = if o.starts_with('<') && o.ends_with('>') {
564                o[1..o.len() - 1].to_string()
565            } else {
566                o
567            };
568
569            triples.push(Triple {
570                subject: clean_s,
571                predicate: clean_p,
572                object: clean_o,
573                provenance: Some(Provenance {
574                    source: "oxigraph".to_string(),
575                    timestamp: "".to_string(),
576                    method: "storage".to_string(),
577                }),
578                embedding: vec![],
579            });
580        }
581
582        Ok(Response::new(TriplesResponse { triples }))
583    }
584
585    async fn query_sparql(
586        &self,
587        request: Request<SparqlRequest>,
588    ) -> Result<Response<SparqlResponse>, Status> {
589        let token = get_token(&request);
590        let req = request.into_inner();
591        let namespace = if req.namespace.is_empty() {
592            "default"
593        } else {
594            &req.namespace
595        };
596
597        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
598            return Err(Status::permission_denied(e));
599        }
600
601        let store = self.get_store(namespace)?;
602
603        match store.query_sparql(&req.query) {
604            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
605            Err(e) => Err(Status::internal(e.to_string())),
606        }
607    }
608
609    async fn delete_namespace_data(
610        &self,
611        request: Request<EmptyRequest>,
612    ) -> Result<Response<DeleteResponse>, Status> {
613        let token = get_token(&request);
614        let req = request.into_inner();
615        let namespace = if req.namespace.is_empty() {
616            "default"
617        } else {
618            &req.namespace
619        };
620
621        if let Err(e) = self.auth.check(token.as_deref(), namespace, "delete") {
622            return Err(Status::permission_denied(e));
623        }
624
625        // Remove from cache
626        self.stores.remove(namespace);
627
628        // Delete directory
629        let path = Path::new(&self.storage_path).join(namespace);
630        if path.exists() {
631            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
632        }
633
634        Ok(Response::new(DeleteResponse {
635            success: true,
636            message: format!("Deleted namespace '{}'", namespace),
637        }))
638    }
639
640    async fn hybrid_search(
641        &self,
642        request: Request<HybridSearchRequest>,
643    ) -> Result<Response<SearchResponse>, Status> {
644        let token = get_token(&request);
645        let req = request.into_inner();
646        let namespace = if req.namespace.is_empty() {
647            "default"
648        } else {
649            &req.namespace
650        };
651
652        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
653            return Err(Status::permission_denied(e));
654        }
655
656        let store = self.get_store(namespace)?;
657
658        let vector_k = req.vector_k as usize;
659        let graph_depth = req.graph_depth;
660
661        let results = match SearchMode::try_from(req.mode) {
662            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
663                .hybrid_search(&req.query, vector_k, graph_depth)
664                .await
665                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
666            _ => vec![],
667        };
668
669        let grpc_results = results
670            .into_iter()
671            .enumerate()
672            .map(|(idx, (uri, score))| SearchResult {
673                node_id: idx as u32,
674                score,
675                content: uri.clone(),
676                uri,
677            })
678            .collect();
679
680        Ok(Response::new(SearchResponse {
681            results: grpc_results,
682        }))
683    }
684
685    async fn apply_reasoning(
686        &self,
687        request: Request<ReasoningRequest>,
688    ) -> Result<Response<ReasoningResponse>, Status> {
689        // Auth check (Reason permission)
690        let token = get_token(&request);
691        let req = request.into_inner();
692        let namespace = if req.namespace.is_empty() {
693            "default"
694        } else {
695            &req.namespace
696        };
697
698        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
699            return Err(Status::permission_denied(e));
700        }
701
702        let store = self.get_store(namespace)?;
703
704        let strategy = match ReasoningStrategy::try_from(req.strategy) {
705            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
706            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
707            _ => InternalStrategy::None,
708        };
709        let strategy_name = format!("{:?}", strategy);
710
711        let reasoner = SynapseReasoner::new(strategy);
712        let start_triples = store.store.len().unwrap_or(0);
713
714        let response = if req.materialize {
715            match reasoner.materialize(&store.store) {
716                Ok(count) => Ok(Response::new(ReasoningResponse {
717                    success: true,
718                    triples_inferred: count as u32,
719                    message: format!(
720                        "Materialized {} triples in namespace '{}'",
721                        count, namespace
722                    ),
723                })),
724                Err(e) => Err(Status::internal(e.to_string())),
725            }
726        } else {
727            match reasoner.apply(&store.store) {
728                Ok(triples) => Ok(Response::new(ReasoningResponse {
729                    success: true,
730                    triples_inferred: triples.len() as u32,
731                    message: format!(
732                        "Found {} inferred triples in namespace '{}'",
733                        triples.len(),
734                        namespace
735                    ),
736                })),
737                Err(e) => Err(Status::internal(e.to_string())),
738            }
739        };
740
741        // Audit Log
742        if let Ok(ref res) = response {
743            let inferred = res.get_ref().triples_inferred as usize;
744            self.audit.log(
745                namespace,
746                &strategy_name,
747                start_triples,
748                inferred,
749                0, // Duplicates skipped not easily tracked here without changing reasoner return signature
750                vec![], // Sample inferences
751            );
752        }
753
754        response
755    }
756}
757
758pub async fn run_mcp_stdio(
759    engine: Arc<MySemanticEngine>,
760) -> Result<(), Box<dyn std::error::Error>> {
761    let server = crate::mcp_stdio::McpStdioServer::new(engine);
762    server.run().await
763}