Skip to main content

synapse_core/
server.rs

1use dashmap::DashMap;
2use std::sync::Arc;
3use tonic::{Request, Response, Status};
4
5pub mod proto {
6    tonic::include_proto!("semantic_engine");
7}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::ingest::IngestionEngine;
13use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::store::{IngestTriple, SynapseStore};
16use std::path::Path;
17
18use crate::audit::InferenceAudit;
19use crate::auth::NamespaceAuth;
20
/// Bearer credential attached to a request's extensions by the auth
/// interceptor so that handlers can read it back without re-parsing the
/// `authorization` metadata entry.
#[derive(Clone)]
pub struct AuthToken(pub String);
23
24#[allow(clippy::result_large_err)]
25pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
26    if let Some(token) = req
27        .metadata()
28        .get("authorization")
29        .and_then(|t| t.to_str().ok())
30        .map(|s| s.trim_start_matches("Bearer ").to_string())
31    {
32        req.extensions_mut().insert(AuthToken(token));
33    }
34    Ok(req)
35}
36
37fn get_token<T>(req: &Request<T>) -> Option<String> {
38    if let Some(token) = req.extensions().get::<AuthToken>() {
39        return Some(token.0.clone());
40    }
41    req.metadata()
42        .get("authorization")
43        .and_then(|t| t.to_str().ok())
44        .map(|s| s.trim_start_matches("Bearer ").to_string())
45}
46
47#[derive(Clone)]
48pub struct MySemanticEngine {
49    pub storage_path: String,
50    pub stores: Arc<DashMap<String, Arc<SynapseStore>>>,
51    pub auth: Arc<NamespaceAuth>,
52    pub audit: Arc<InferenceAudit>,
53}
54
55impl MySemanticEngine {
56    pub fn new(storage_path: &str) -> Self {
57        let auth = Arc::new(NamespaceAuth::new());
58        auth.load_from_env();
59
60        Self {
61            storage_path: storage_path.to_string(),
62            stores: Arc::new(DashMap::new()),
63            auth,
64            audit: Arc::new(InferenceAudit::new()),
65        }
66    }
67
68    pub async fn shutdown(&self) {
69        eprintln!("Shutting down... flushing {} stores", self.stores.len());
70        for entry in self.stores.iter() {
71            let store = entry.value();
72            if let Err(e) = store.flush() {
73                eprintln!("Failed to flush store '{}': {}", entry.key(), e);
74            }
75        }
76        eprintln!("Shutdown complete.");
77    }
78
79    #[allow(clippy::result_large_err)]
80    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
81        if let Some(store) = self.stores.get(namespace) {
82            return Ok(store.clone());
83        }
84
85        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
86            Status::internal(format!(
87                "Failed to open store for namespace '{}': {}",
88                namespace, e
89            ))
90        })?;
91
92        let store_arc = Arc::new(store);
93        self.stores.insert(namespace.to_string(), store_arc.clone());
94        Ok(store_arc)
95    }
96}
97
98#[tonic::async_trait]
99impl SemanticEngine for MySemanticEngine {
100    async fn ingest_triples(
101        &self,
102        request: Request<IngestRequest>,
103    ) -> Result<Response<IngestResponse>, Status> {
104        // Auth check (Write permission)
105        let token = get_token(&request);
106        let req = request.into_inner();
107        let namespace = if req.namespace.is_empty() {
108            "default"
109        } else {
110            &req.namespace
111        };
112
113        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
114            return Err(Status::permission_denied(e));
115        }
116
117        let store = self.get_store(namespace)?;
118
119        // Log provenance for audit
120        let timestamp = chrono::Utc::now().to_rfc3339();
121        let triple_count = req.triples.len();
122        let mut sources: Vec<String> = Vec::new();
123
124        let triples: Vec<IngestTriple> = req
125            .triples
126            .into_iter()
127            .map(|t| {
128                // Capture provenance sources for logging
129                if let Some(ref prov) = t.provenance {
130                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
131                        sources.push(prov.source.clone());
132                    }
133                }
134                IngestTriple {
135                    subject: t.subject,
136                    predicate: t.predicate,
137                    object: t.object,
138                    provenance: t.provenance.map(|p| crate::store::Provenance {
139                        source: p.source,
140                        timestamp: p.timestamp,
141                        method: p.method,
142                    }),
143                }
144            })
145            .collect();
146
147        match store.ingest_triples(triples).await {
148            Ok((added, _)) => {
149                // Log ingestion for audit trail
150                eprintln!(
151                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
152                    sources
153                );
154                Ok(Response::new(IngestResponse {
155                    nodes_added: added,
156                    edges_added: added,
157                }))
158            }
159            Err(e) => Err(Status::internal(e.to_string())),
160        }
161    }
162
163    async fn ingest_file(
164        &self,
165        request: Request<IngestFileRequest>,
166    ) -> Result<Response<IngestResponse>, Status> {
167        // Auth check (Write permission) - previously missing? or just implicit?
168        // Note: The original code didn't check auth for ingest_file!
169        // Adding it now for consistency as we are touching auth.
170        let token = get_token(&request);
171        let req = request.into_inner();
172        let namespace = if req.namespace.is_empty() {
173            "default"
174        } else {
175            &req.namespace
176        };
177
178        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
179            return Err(Status::permission_denied(e));
180        }
181        let store = self.get_store(namespace)?;
182
183        let engine = IngestionEngine::new(store);
184        let path = Path::new(&req.file_path);
185
186        match engine.ingest_file(path, namespace).await {
187            Ok(count) => Ok(Response::new(IngestResponse {
188                nodes_added: count,
189                edges_added: count,
190            })),
191            Err(e) => Err(Status::internal(e.to_string())),
192        }
193    }
194
195    async fn get_neighbors(
196        &self,
197        request: Request<NodeRequest>,
198    ) -> Result<Response<NeighborResponse>, Status> {
199        let token = get_token(&request);
200        let req = request.into_inner();
201        let namespace = if req.namespace.is_empty() {
202            "default"
203        } else {
204            &req.namespace
205        };
206
207        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
208            return Err(Status::permission_denied(e));
209        }
210
211        let store = self.get_store(namespace)?;
212
213        let direction = if req.direction.is_empty() {
214            "outgoing"
215        } else {
216            &req.direction
217        };
218        let edge_filter = if req.edge_filter.is_empty() {
219            None
220        } else {
221            Some(req.edge_filter.as_str())
222        };
223        let node_type_filter = if req.node_type_filter.is_empty() {
224            None
225        } else {
226            Some(req.node_type_filter.as_str())
227        };
228        let max_depth = if req.depth == 0 {
229            1
230        } else {
231            req.depth as usize
232        };
233        let limit_per_layer = if req.limit_per_layer == 0 {
234            usize::MAX
235        } else {
236            req.limit_per_layer as usize
237        };
238
239        let mut neighbors = Vec::new();
240        let mut visited = std::collections::HashSet::new();
241        let mut current_frontier = Vec::new();
242
243        // Start with the initial node
244        if let Some(start_uri) = store.get_uri(req.node_id) {
245            current_frontier.push(start_uri.clone());
246            visited.insert(start_uri);
247        }
248
249        // BFS traversal up to max_depth
250        for current_depth in 1..=max_depth {
251            let mut next_frontier = Vec::new();
252            let mut layer_count = 0;
253            let base_score = 1.0 / current_depth as f32; // Path scoring: closer = higher
254
255            for uri in &current_frontier {
256                if layer_count >= limit_per_layer {
257                    break;
258                }
259
260                // Query outgoing edges (URI as subject)
261                if direction == "outgoing" || direction == "both" {
262                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
263                        for quad in
264                            store
265                                .store
266                                .quads_for_pattern(Some(subj.into()), None, None, None)
267                        {
268                            if layer_count >= limit_per_layer {
269                                break;
270                            }
271                            if let Ok(q) = quad {
272                                let pred = q.predicate.to_string();
273                                // Apply edge filter if specified
274                                if let Some(filter) = edge_filter {
275                                    if !pred.contains(filter) {
276                                        continue;
277                                    }
278                                }
279                                let obj_term = q.object;
280                                let obj_uri = obj_term.to_string();
281
282                                // Node Type Filter Logic
283                                if let Some(type_filter) = node_type_filter {
284                                    let passed =
285                                        if let oxigraph::model::Term::NamedNode(ref n) = obj_term {
286                                            let rdf_type = oxigraph::model::NamedNodeRef::new(
287                                                "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
288                                            )
289                                            .unwrap();
290                                            if let Ok(target_type) =
291                                                oxigraph::model::NamedNodeRef::new(type_filter)
292                                            {
293                                                store
294                                                    .store
295                                                    .quads_for_pattern(
296                                                        Some(n.into()),
297                                                        Some(rdf_type),
298                                                        Some(target_type.into()),
299                                                        None,
300                                                    )
301                                                    .next()
302                                                    .is_some()
303                                            } else {
304                                                false
305                                            }
306                                        } else {
307                                            false
308                                        };
309                                    if !passed {
310                                        continue;
311                                    }
312                                }
313
314                                let clean_uri = match &obj_term {
315                                    oxigraph::model::Term::NamedNode(n) => n.as_str(),
316                                    _ => &obj_uri,
317                                };
318
319                                    // Always add to neighbors if not already in neighbors list to avoid duplicates there
320                                    // But we must allow revisiting nodes for graph expansion if we want to find paths?
321                                    // BFS typically avoids cycles by checking visited.
322
323                                    // NOTE: visited set prevents processing same node twice in BFS.
324                                    // If we reach a node that was already visited in a previous layer (or this layer), skip it.
325                                if !visited.contains(&obj_uri) {
326                                    visited.insert(obj_uri.clone());
327                                    let obj_id = store.get_or_create_id(&obj_uri);
328
329                                    let mut neighbor_score = base_score;
330                                    if req.scoring_strategy == "degree" {
331                                        let degree = store.get_degree(clean_uri);
332                                            neighbor_score /= (degree as f32 + 1.0).ln().max(0.1);
333                                    }
334
335                                    neighbors.push(Neighbor {
336                                        node_id: obj_id,
337                                        edge_type: pred,
338                                        uri: obj_uri.clone(), // This is the N-Triples formatted string for display
339                                        direction: "outgoing".to_string(),
340                                        depth: current_depth as u32,
341                                        score: neighbor_score,
342                                    });
343                                    // Use clean_uri for next frontier to ensure we query with raw URI, not <uri>
344                                    next_frontier.push(clean_uri.to_string());
345                                    layer_count += 1;
346                                }
347                            }
348                        }
349                    }
350                }
351
352                // Query incoming edges (URI as object)
353                if direction == "incoming" || direction == "both" {
354                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
355                        for quad in
356                            store
357                                .store
358                                .quads_for_pattern(None, None, Some(obj.into()), None)
359                        {
360                            if layer_count >= limit_per_layer {
361                                break;
362                            }
363                            if let Ok(q) = quad {
364                                let pred = q.predicate.to_string();
365                                // Apply edge filter if specified
366                                if let Some(filter) = edge_filter {
367                                    if !pred.contains(filter) {
368                                        continue;
369                                    }
370                                }
371                                let subj_term = q.subject;
372                                let subj_uri = subj_term.to_string();
373
374                                // Node Type Filter Logic
375                                if let Some(type_filter) = node_type_filter {
376                                    let passed = if let oxigraph::model::Subject::NamedNode(ref n) =
377                                        subj_term
378                                    {
379                                        let rdf_type = oxigraph::model::NamedNodeRef::new(
380                                            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
381                                        )
382                                        .unwrap();
383                                        if let Ok(target_type) =
384                                            oxigraph::model::NamedNodeRef::new(type_filter)
385                                        {
386                                            store
387                                                .store
388                                                .quads_for_pattern(
389                                                    Some(n.into()),
390                                                    Some(rdf_type),
391                                                    Some(target_type.into()),
392                                                    None,
393                                                )
394                                                .next()
395                                                .is_some()
396                                        } else {
397                                            false
398                                        }
399                                    } else {
400                                        false
401                                    };
402                                    if !passed {
403                                        continue;
404                                    }
405                                }
406
407                                let clean_uri = match &subj_term {
408                                    oxigraph::model::Subject::NamedNode(n) => n.as_str(),
409                                    _ => &subj_uri,
410                                };
411
412                                if !visited.contains(&subj_uri) {
413                                    visited.insert(subj_uri.clone());
414                                    let subj_id = store.get_or_create_id(&subj_uri);
415
416                                    let mut neighbor_score = base_score;
417                                    if req.scoring_strategy == "degree" {
418                                        let degree = store.get_degree(clean_uri);
419                                        // Penalize super nodes
420                                        neighbor_score /= (degree as f32 + 1.0).ln().max(0.1);
421                                    }
422
423                                    neighbors.push(Neighbor {
424                                        node_id: subj_id,
425                                        edge_type: pred,
426                                        uri: subj_uri.clone(),
427                                        direction: "incoming".to_string(),
428                                        depth: current_depth as u32,
429                                        score: neighbor_score,
430                                    });
431                                    // Use clean_uri for next frontier
432                                    next_frontier.push(clean_uri.to_string());
433                                    layer_count += 1;
434                                }
435                            }
436                        }
437                    }
438                }
439            }
440
441            current_frontier = next_frontier;
442            if current_frontier.is_empty() {
443                break;
444            }
445        }
446
447        // Sort by score (highest first)
448        neighbors.sort_by(|a, b| {
449            b.score
450                .partial_cmp(&a.score)
451                .unwrap_or(std::cmp::Ordering::Equal)
452        });
453
454        Ok(Response::new(NeighborResponse { neighbors }))
455    }
456
457    async fn search(
458        &self,
459        request: Request<SearchRequest>,
460    ) -> Result<Response<SearchResponse>, Status> {
461        let token = get_token(&request);
462        let req = request.into_inner();
463        let namespace = if req.namespace.is_empty() {
464            "default"
465        } else {
466            &req.namespace
467        };
468
469        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
470            return Err(Status::permission_denied(e));
471        }
472
473        let store = self.get_store(namespace)?;
474
475        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
476            Ok(results) => {
477                let grpc_results = results
478                    .into_iter()
479                    .enumerate()
480                    .map(|(idx, (uri, score))| SearchResult {
481                        node_id: idx as u32,
482                        score,
483                        content: uri.clone(),
484                        uri,
485                    })
486                    .collect();
487                Ok(Response::new(SearchResponse {
488                    results: grpc_results,
489                }))
490            }
491            Err(e) => Err(Status::internal(e.to_string())),
492        }
493    }
494
495    async fn resolve_id(
496        &self,
497        request: Request<ResolveRequest>,
498    ) -> Result<Response<ResolveResponse>, Status> {
499        let token = get_token(&request);
500        let req = request.into_inner();
501        let namespace = if req.namespace.is_empty() {
502            "default"
503        } else {
504            &req.namespace
505        };
506
507        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
508            return Err(Status::permission_denied(e));
509        }
510
511        let store = self.get_store(namespace)?;
512
513        let uri = store.ensure_uri(&req.content);
514
515        // Look up the URI in our mapping
516        let uri_to_id = store.uri_to_id.read().unwrap();
517        if let Some(&node_id) = uri_to_id.get(&uri) {
518            Ok(Response::new(ResolveResponse {
519                node_id,
520                found: true,
521            }))
522        } else {
523            Ok(Response::new(ResolveResponse {
524                node_id: 0,
525                found: false,
526            }))
527        }
528    }
529
530    async fn get_all_triples(
531        &self,
532        request: Request<EmptyRequest>,
533    ) -> Result<Response<TriplesResponse>, Status> {
534        let token = get_token(&request);
535        let req = request.into_inner();
536        let namespace = if req.namespace.is_empty() {
537            "default"
538        } else {
539            &req.namespace
540        };
541
542        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
543            return Err(Status::permission_denied(e));
544        }
545
546        let store = self.get_store(namespace)?;
547
548        let mut triples = Vec::new();
549
550        for quad in store.store.iter().map(|q| q.unwrap()) {
551            let s = quad.subject.to_string();
552            let p = quad.predicate.to_string();
553            let o = quad.object.to_string();
554
555            // Clean up NTriples formatting (<uri> -> uri)
556            let clean_s = if s.starts_with('<') && s.ends_with('>') {
557                s[1..s.len() - 1].to_string()
558            } else {
559                s
560            };
561            let clean_p = if p.starts_with('<') && p.ends_with('>') {
562                p[1..p.len() - 1].to_string()
563            } else {
564                p
565            };
566            let clean_o = if o.starts_with('<') && o.ends_with('>') {
567                o[1..o.len() - 1].to_string()
568            } else {
569                o
570            };
571
572            triples.push(Triple {
573                subject: clean_s,
574                predicate: clean_p,
575                object: clean_o,
576                provenance: Some(Provenance {
577                    source: "oxigraph".to_string(),
578                    timestamp: "".to_string(),
579                    method: "storage".to_string(),
580                }),
581                embedding: vec![],
582            });
583        }
584
585        Ok(Response::new(TriplesResponse { triples }))
586    }
587
588    async fn query_sparql(
589        &self,
590        request: Request<SparqlRequest>,
591    ) -> Result<Response<SparqlResponse>, Status> {
592        let token = get_token(&request);
593        let req = request.into_inner();
594        let namespace = if req.namespace.is_empty() {
595            "default"
596        } else {
597            &req.namespace
598        };
599
600        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
601            return Err(Status::permission_denied(e));
602        }
603
604        let store = self.get_store(namespace)?;
605
606        match store.query_sparql(&req.query) {
607            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
608            Err(e) => Err(Status::internal(e.to_string())),
609        }
610    }
611
612    async fn delete_namespace_data(
613        &self,
614        request: Request<EmptyRequest>,
615    ) -> Result<Response<DeleteResponse>, Status> {
616        let token = get_token(&request);
617        let req = request.into_inner();
618        let namespace = if req.namespace.is_empty() {
619            "default"
620        } else {
621            &req.namespace
622        };
623
624        if let Err(e) = self.auth.check(token.as_deref(), namespace, "delete") {
625            return Err(Status::permission_denied(e));
626        }
627
628        // Remove from cache
629        self.stores.remove(namespace);
630
631        // Delete directory
632        let path = Path::new(&self.storage_path).join(namespace);
633        if path.exists() {
634            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
635        }
636
637        Ok(Response::new(DeleteResponse {
638            success: true,
639            message: format!("Deleted namespace '{}'", namespace),
640        }))
641    }
642
643    async fn hybrid_search(
644        &self,
645        request: Request<HybridSearchRequest>,
646    ) -> Result<Response<SearchResponse>, Status> {
647        let token = get_token(&request);
648        let req = request.into_inner();
649        let namespace = if req.namespace.is_empty() {
650            "default"
651        } else {
652            &req.namespace
653        };
654
655        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
656            return Err(Status::permission_denied(e));
657        }
658
659        let store = self.get_store(namespace)?;
660
661        let vector_k = req.vector_k as usize;
662        let graph_depth = req.graph_depth;
663
664        let results = match SearchMode::try_from(req.mode) {
665            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
666                .hybrid_search(&req.query, vector_k, graph_depth)
667                .await
668                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
669            _ => vec![],
670        };
671
672        let grpc_results = results
673            .into_iter()
674            .enumerate()
675            .map(|(idx, (uri, score))| SearchResult {
676                node_id: idx as u32,
677                score,
678                content: uri.clone(),
679                uri,
680            })
681            .collect();
682
683        Ok(Response::new(SearchResponse {
684            results: grpc_results,
685        }))
686    }
687
688    async fn apply_reasoning(
689        &self,
690        request: Request<ReasoningRequest>,
691    ) -> Result<Response<ReasoningResponse>, Status> {
692        // Auth check (Reason permission)
693        let token = get_token(&request);
694        let req = request.into_inner();
695        let namespace = if req.namespace.is_empty() {
696            "default"
697        } else {
698            &req.namespace
699        };
700
701        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
702            return Err(Status::permission_denied(e));
703        }
704
705        let store = self.get_store(namespace)?;
706
707        let strategy = match ReasoningStrategy::try_from(req.strategy) {
708            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
709            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
710            _ => InternalStrategy::None,
711        };
712        let strategy_name = format!("{:?}", strategy);
713
714        let reasoner = SynapseReasoner::new(strategy);
715        let start_triples = store.store.len().unwrap_or(0);
716
717        let response = if req.materialize {
718            match reasoner.materialize(&store.store) {
719                Ok(count) => Ok(Response::new(ReasoningResponse {
720                    success: true,
721                    triples_inferred: count as u32,
722                    message: format!(
723                        "Materialized {} triples in namespace '{}'",
724                        count, namespace
725                    ),
726                })),
727                Err(e) => Err(Status::internal(e.to_string())),
728            }
729        } else {
730            match reasoner.apply(&store.store) {
731                Ok(triples) => Ok(Response::new(ReasoningResponse {
732                    success: true,
733                    triples_inferred: triples.len() as u32,
734                    message: format!(
735                        "Found {} inferred triples in namespace '{}'",
736                        triples.len(),
737                        namespace
738                    ),
739                })),
740                Err(e) => Err(Status::internal(e.to_string())),
741            }
742        };
743
744        // Audit Log
745        if let Ok(ref res) = response {
746            let inferred = res.get_ref().triples_inferred as usize;
747            self.audit.log(
748                namespace,
749                &strategy_name,
750                start_triples,
751                inferred,
752                0, // Duplicates skipped not easily tracked here without changing reasoner return signature
753                vec![], // Sample inferences
754            );
755        }
756
757        response
758    }
759}
760
761pub async fn run_mcp_stdio(
762    engine: Arc<MySemanticEngine>,
763) -> Result<(), Box<dyn std::error::Error>> {
764    let server = crate::mcp_stdio::McpStdioServer::new(engine);
765    server.run().await
766}