1use dashmap::DashMap;
2use std::sync::Arc;
3use tonic::{Request, Response, Status};
4
/// gRPC message and service types generated by tonic from the
/// `semantic_engine` protobuf package.
pub mod proto {
    tonic::include_proto!("semantic_engine");
}
8
9use proto::semantic_engine_server::SemanticEngine;
10use proto::*;
11
12use crate::ingest::IngestionEngine;
13use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
14use crate::server::proto::{ReasoningStrategy, SearchMode};
15use crate::store::{IngestTriple, SynapseStore};
16use std::path::Path;
17
18use crate::audit::InferenceAudit;
19use crate::auth::NamespaceAuth;
20
/// Bearer token taken from a request's `authorization` metadata.
///
/// `auth_interceptor` stores a value of this type in the request extensions
/// and `get_token` reads it back.
#[derive(Clone)]
pub struct AuthToken(pub String);
23
24#[allow(clippy::result_large_err)]
25pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
26 if let Some(token) = req
27 .metadata()
28 .get("authorization")
29 .and_then(|t| t.to_str().ok())
30 .map(|s| s.trim_start_matches("Bearer ").to_string())
31 {
32 req.extensions_mut().insert(AuthToken(token));
33 }
34 Ok(req)
35}
36
37fn get_token<T>(req: &Request<T>) -> Option<String> {
38 if let Some(token) = req.extensions().get::<AuthToken>() {
39 return Some(token.0.clone());
40 }
41 req.metadata()
42 .get("authorization")
43 .and_then(|t| t.to_str().ok())
44 .map(|s| s.trim_start_matches("Bearer ").to_string())
45}
46
/// gRPC service state shared by all request handlers.
///
/// Cloning is cheap: every field except `storage_path` is behind an `Arc`.
#[derive(Clone)]
pub struct MySemanticEngine {
    /// Root directory handed to `SynapseStore::open`; namespace directories
    /// are resolved relative to it when a namespace is deleted.
    pub storage_path: String,
    /// Cache of opened stores, keyed by namespace name.
    pub stores: Arc<DashMap<String, Arc<SynapseStore>>>,
    /// Per-namespace authorization checker consulted by every handler.
    pub auth: Arc<NamespaceAuth>,
    /// Audit log written to after successful reasoning runs.
    pub audit: Arc<InferenceAudit>,
}
54
55impl MySemanticEngine {
56 pub fn new(storage_path: &str) -> Self {
57 let auth = Arc::new(NamespaceAuth::new());
58 auth.load_from_env();
59
60 Self {
61 storage_path: storage_path.to_string(),
62 stores: Arc::new(DashMap::new()),
63 auth,
64 audit: Arc::new(InferenceAudit::new()),
65 }
66 }
67
68 pub async fn shutdown(&self) {
69 eprintln!("Shutting down... flushing {} stores", self.stores.len());
70 for entry in self.stores.iter() {
71 let store = entry.value();
72 if let Err(e) = store.flush() {
73 eprintln!("Failed to flush store '{}': {}", entry.key(), e);
74 }
75 }
76 eprintln!("Shutdown complete.");
77 }
78
79 #[allow(clippy::result_large_err)]
80 pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
81 if let Some(store) = self.stores.get(namespace) {
82 return Ok(store.clone());
83 }
84
85 let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
86 Status::internal(format!(
87 "Failed to open store for namespace '{}': {}",
88 namespace, e
89 ))
90 })?;
91
92 let store_arc = Arc::new(store);
93 self.stores.insert(namespace.to_string(), store_arc.clone());
94 Ok(store_arc)
95 }
96}
97
98#[tonic::async_trait]
99impl SemanticEngine for MySemanticEngine {
100 async fn ingest_triples(
101 &self,
102 request: Request<IngestRequest>,
103 ) -> Result<Response<IngestResponse>, Status> {
104 let token = get_token(&request);
106 let req = request.into_inner();
107 let namespace = if req.namespace.is_empty() {
108 "default"
109 } else {
110 &req.namespace
111 };
112
113 if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
114 return Err(Status::permission_denied(e));
115 }
116
117 let store = self.get_store(namespace)?;
118
119 let timestamp = chrono::Utc::now().to_rfc3339();
121 let triple_count = req.triples.len();
122 let mut sources: Vec<String> = Vec::new();
123
124 let triples: Vec<IngestTriple> = req
125 .triples
126 .into_iter()
127 .map(|t| {
128 if let Some(ref prov) = t.provenance {
130 if !prov.source.is_empty() && !sources.contains(&prov.source) {
131 sources.push(prov.source.clone());
132 }
133 }
134 IngestTriple {
135 subject: t.subject,
136 predicate: t.predicate,
137 object: t.object,
138 provenance: t.provenance.map(|p| crate::store::Provenance {
139 source: p.source,
140 timestamp: p.timestamp,
141 method: p.method,
142 }),
143 }
144 })
145 .collect();
146
147 match store.ingest_triples(triples).await {
148 Ok((added, _)) => {
149 eprintln!(
151 "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
152 sources
153 );
154 Ok(Response::new(IngestResponse {
155 nodes_added: added,
156 edges_added: added,
157 }))
158 }
159 Err(e) => Err(Status::internal(e.to_string())),
160 }
161 }
162
163 async fn ingest_file(
164 &self,
165 request: Request<IngestFileRequest>,
166 ) -> Result<Response<IngestResponse>, Status> {
167 let token = get_token(&request);
171 let req = request.into_inner();
172 let namespace = if req.namespace.is_empty() {
173 "default"
174 } else {
175 &req.namespace
176 };
177
178 if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
179 return Err(Status::permission_denied(e));
180 }
181 let store = self.get_store(namespace)?;
182
183 let engine = IngestionEngine::new(store);
184 let path = Path::new(&req.file_path);
185
186 match engine.ingest_file(path, namespace).await {
187 Ok(count) => Ok(Response::new(IngestResponse {
188 nodes_added: count,
189 edges_added: count,
190 })),
191 Err(e) => Err(Status::internal(e.to_string())),
192 }
193 }
194
195 async fn get_neighbors(
196 &self,
197 request: Request<NodeRequest>,
198 ) -> Result<Response<NeighborResponse>, Status> {
199 let token = get_token(&request);
200 let req = request.into_inner();
201 let namespace = if req.namespace.is_empty() {
202 "default"
203 } else {
204 &req.namespace
205 };
206
207 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
208 return Err(Status::permission_denied(e));
209 }
210
211 let store = self.get_store(namespace)?;
212
213 let direction = if req.direction.is_empty() {
214 "outgoing"
215 } else {
216 &req.direction
217 };
218 let edge_filter = if req.edge_filter.is_empty() {
219 None
220 } else {
221 Some(req.edge_filter.as_str())
222 };
223 let node_type_filter = if req.node_type_filter.is_empty() {
224 None
225 } else {
226 Some(req.node_type_filter.as_str())
227 };
228 let max_depth = if req.depth == 0 {
229 1
230 } else {
231 req.depth as usize
232 };
233 let limit_per_layer = if req.limit_per_layer == 0 {
234 usize::MAX
235 } else {
236 req.limit_per_layer as usize
237 };
238
239 let mut neighbors = Vec::new();
240 let mut visited = std::collections::HashSet::new();
241 let mut current_frontier = Vec::new();
242
243 if let Some(start_uri) = store.get_uri(req.node_id) {
245 current_frontier.push(start_uri.clone());
246 visited.insert(start_uri);
247 }
248
249 for current_depth in 1..=max_depth {
251 let mut next_frontier = Vec::new();
252 let mut layer_count = 0;
253 let base_score = 1.0 / current_depth as f32; for uri in ¤t_frontier {
256 if layer_count >= limit_per_layer {
257 break;
258 }
259
260 if direction == "outgoing" || direction == "both" {
262 if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
263 for quad in
264 store
265 .store
266 .quads_for_pattern(Some(subj.into()), None, None, None)
267 {
268 if layer_count >= limit_per_layer {
269 break;
270 }
271 if let Ok(q) = quad {
272 let pred = q.predicate.to_string();
273 if let Some(filter) = edge_filter {
275 if !pred.contains(filter) {
276 continue;
277 }
278 }
279 let obj_term = q.object;
280 let obj_uri = obj_term.to_string();
281
282 if let Some(type_filter) = node_type_filter {
284 let passed =
285 if let oxigraph::model::Term::NamedNode(ref n) = obj_term {
286 let rdf_type = oxigraph::model::NamedNodeRef::new(
287 "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
288 )
289 .unwrap();
290 if let Ok(target_type) =
291 oxigraph::model::NamedNodeRef::new(type_filter)
292 {
293 store
294 .store
295 .quads_for_pattern(
296 Some(n.into()),
297 Some(rdf_type),
298 Some(target_type.into()),
299 None,
300 )
301 .next()
302 .is_some()
303 } else {
304 false
305 }
306 } else {
307 false
308 };
309 if !passed {
310 continue;
311 }
312 }
313
314 let clean_uri = match &obj_term {
315 oxigraph::model::Term::NamedNode(n) => n.as_str(),
316 _ => &obj_uri,
317 };
318
319 if !visited.contains(&obj_uri) {
326 visited.insert(obj_uri.clone());
327 let obj_id = store.get_or_create_id(&obj_uri);
328
329 let mut neighbor_score = base_score;
330 if req.scoring_strategy == "degree" {
331 let degree = store.get_degree(clean_uri);
332 neighbor_score /= (degree as f32 + 1.0).ln().max(0.1);
333 }
334
335 neighbors.push(Neighbor {
336 node_id: obj_id,
337 edge_type: pred,
338 uri: obj_uri.clone(), direction: "outgoing".to_string(),
340 depth: current_depth as u32,
341 score: neighbor_score,
342 });
343 next_frontier.push(clean_uri.to_string());
345 layer_count += 1;
346 }
347 }
348 }
349 }
350 }
351
352 if direction == "incoming" || direction == "both" {
354 if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
355 for quad in
356 store
357 .store
358 .quads_for_pattern(None, None, Some(obj.into()), None)
359 {
360 if layer_count >= limit_per_layer {
361 break;
362 }
363 if let Ok(q) = quad {
364 let pred = q.predicate.to_string();
365 if let Some(filter) = edge_filter {
367 if !pred.contains(filter) {
368 continue;
369 }
370 }
371 let subj_term = q.subject;
372 let subj_uri = subj_term.to_string();
373
374 if let Some(type_filter) = node_type_filter {
376 let passed = if let oxigraph::model::Subject::NamedNode(ref n) =
377 subj_term
378 {
379 let rdf_type = oxigraph::model::NamedNodeRef::new(
380 "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
381 )
382 .unwrap();
383 if let Ok(target_type) =
384 oxigraph::model::NamedNodeRef::new(type_filter)
385 {
386 store
387 .store
388 .quads_for_pattern(
389 Some(n.into()),
390 Some(rdf_type),
391 Some(target_type.into()),
392 None,
393 )
394 .next()
395 .is_some()
396 } else {
397 false
398 }
399 } else {
400 false
401 };
402 if !passed {
403 continue;
404 }
405 }
406
407 let clean_uri = match &subj_term {
408 oxigraph::model::Subject::NamedNode(n) => n.as_str(),
409 _ => &subj_uri,
410 };
411
412 if !visited.contains(&subj_uri) {
413 visited.insert(subj_uri.clone());
414 let subj_id = store.get_or_create_id(&subj_uri);
415
416 let mut neighbor_score = base_score;
417 if req.scoring_strategy == "degree" {
418 let degree = store.get_degree(clean_uri);
419 neighbor_score /= (degree as f32 + 1.0).ln().max(0.1);
421 }
422
423 neighbors.push(Neighbor {
424 node_id: subj_id,
425 edge_type: pred,
426 uri: subj_uri.clone(),
427 direction: "incoming".to_string(),
428 depth: current_depth as u32,
429 score: neighbor_score,
430 });
431 next_frontier.push(clean_uri.to_string());
433 layer_count += 1;
434 }
435 }
436 }
437 }
438 }
439 }
440
441 current_frontier = next_frontier;
442 if current_frontier.is_empty() {
443 break;
444 }
445 }
446
447 neighbors.sort_by(|a, b| {
449 b.score
450 .partial_cmp(&a.score)
451 .unwrap_or(std::cmp::Ordering::Equal)
452 });
453
454 Ok(Response::new(NeighborResponse { neighbors }))
455 }
456
457 async fn search(
458 &self,
459 request: Request<SearchRequest>,
460 ) -> Result<Response<SearchResponse>, Status> {
461 let token = get_token(&request);
462 let req = request.into_inner();
463 let namespace = if req.namespace.is_empty() {
464 "default"
465 } else {
466 &req.namespace
467 };
468
469 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
470 return Err(Status::permission_denied(e));
471 }
472
473 let store = self.get_store(namespace)?;
474
475 match store.hybrid_search(&req.query, req.limit as usize, 0).await {
476 Ok(results) => {
477 let grpc_results = results
478 .into_iter()
479 .enumerate()
480 .map(|(idx, (uri, score))| SearchResult {
481 node_id: idx as u32,
482 score,
483 content: uri.clone(),
484 uri,
485 })
486 .collect();
487 Ok(Response::new(SearchResponse {
488 results: grpc_results,
489 }))
490 }
491 Err(e) => Err(Status::internal(e.to_string())),
492 }
493 }
494
495 async fn resolve_id(
496 &self,
497 request: Request<ResolveRequest>,
498 ) -> Result<Response<ResolveResponse>, Status> {
499 let token = get_token(&request);
500 let req = request.into_inner();
501 let namespace = if req.namespace.is_empty() {
502 "default"
503 } else {
504 &req.namespace
505 };
506
507 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
508 return Err(Status::permission_denied(e));
509 }
510
511 let store = self.get_store(namespace)?;
512
513 let uri = store.ensure_uri(&req.content);
514
515 let uri_to_id = store.uri_to_id.read().unwrap();
517 if let Some(&node_id) = uri_to_id.get(&uri) {
518 Ok(Response::new(ResolveResponse {
519 node_id,
520 found: true,
521 }))
522 } else {
523 Ok(Response::new(ResolveResponse {
524 node_id: 0,
525 found: false,
526 }))
527 }
528 }
529
530 async fn get_all_triples(
531 &self,
532 request: Request<EmptyRequest>,
533 ) -> Result<Response<TriplesResponse>, Status> {
534 let token = get_token(&request);
535 let req = request.into_inner();
536 let namespace = if req.namespace.is_empty() {
537 "default"
538 } else {
539 &req.namespace
540 };
541
542 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
543 return Err(Status::permission_denied(e));
544 }
545
546 let store = self.get_store(namespace)?;
547
548 let mut triples = Vec::new();
549
550 for quad in store.store.iter().map(|q| q.unwrap()) {
551 let s = quad.subject.to_string();
552 let p = quad.predicate.to_string();
553 let o = quad.object.to_string();
554
555 let clean_s = if s.starts_with('<') && s.ends_with('>') {
557 s[1..s.len() - 1].to_string()
558 } else {
559 s
560 };
561 let clean_p = if p.starts_with('<') && p.ends_with('>') {
562 p[1..p.len() - 1].to_string()
563 } else {
564 p
565 };
566 let clean_o = if o.starts_with('<') && o.ends_with('>') {
567 o[1..o.len() - 1].to_string()
568 } else {
569 o
570 };
571
572 triples.push(Triple {
573 subject: clean_s,
574 predicate: clean_p,
575 object: clean_o,
576 provenance: Some(Provenance {
577 source: "oxigraph".to_string(),
578 timestamp: "".to_string(),
579 method: "storage".to_string(),
580 }),
581 embedding: vec![],
582 });
583 }
584
585 Ok(Response::new(TriplesResponse { triples }))
586 }
587
588 async fn query_sparql(
589 &self,
590 request: Request<SparqlRequest>,
591 ) -> Result<Response<SparqlResponse>, Status> {
592 let token = get_token(&request);
593 let req = request.into_inner();
594 let namespace = if req.namespace.is_empty() {
595 "default"
596 } else {
597 &req.namespace
598 };
599
600 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
601 return Err(Status::permission_denied(e));
602 }
603
604 let store = self.get_store(namespace)?;
605
606 match store.query_sparql(&req.query) {
607 Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
608 Err(e) => Err(Status::internal(e.to_string())),
609 }
610 }
611
612 async fn delete_namespace_data(
613 &self,
614 request: Request<EmptyRequest>,
615 ) -> Result<Response<DeleteResponse>, Status> {
616 let token = get_token(&request);
617 let req = request.into_inner();
618 let namespace = if req.namespace.is_empty() {
619 "default"
620 } else {
621 &req.namespace
622 };
623
624 if let Err(e) = self.auth.check(token.as_deref(), namespace, "delete") {
625 return Err(Status::permission_denied(e));
626 }
627
628 self.stores.remove(namespace);
630
631 let path = Path::new(&self.storage_path).join(namespace);
633 if path.exists() {
634 std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
635 }
636
637 Ok(Response::new(DeleteResponse {
638 success: true,
639 message: format!("Deleted namespace '{}'", namespace),
640 }))
641 }
642
643 async fn hybrid_search(
644 &self,
645 request: Request<HybridSearchRequest>,
646 ) -> Result<Response<SearchResponse>, Status> {
647 let token = get_token(&request);
648 let req = request.into_inner();
649 let namespace = if req.namespace.is_empty() {
650 "default"
651 } else {
652 &req.namespace
653 };
654
655 if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
656 return Err(Status::permission_denied(e));
657 }
658
659 let store = self.get_store(namespace)?;
660
661 let vector_k = req.vector_k as usize;
662 let graph_depth = req.graph_depth;
663
664 let results = match SearchMode::try_from(req.mode) {
665 Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
666 .hybrid_search(&req.query, vector_k, graph_depth)
667 .await
668 .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
669 _ => vec![],
670 };
671
672 let grpc_results = results
673 .into_iter()
674 .enumerate()
675 .map(|(idx, (uri, score))| SearchResult {
676 node_id: idx as u32,
677 score,
678 content: uri.clone(),
679 uri,
680 })
681 .collect();
682
683 Ok(Response::new(SearchResponse {
684 results: grpc_results,
685 }))
686 }
687
688 async fn apply_reasoning(
689 &self,
690 request: Request<ReasoningRequest>,
691 ) -> Result<Response<ReasoningResponse>, Status> {
692 let token = get_token(&request);
694 let req = request.into_inner();
695 let namespace = if req.namespace.is_empty() {
696 "default"
697 } else {
698 &req.namespace
699 };
700
701 if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
702 return Err(Status::permission_denied(e));
703 }
704
705 let store = self.get_store(namespace)?;
706
707 let strategy = match ReasoningStrategy::try_from(req.strategy) {
708 Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
709 Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
710 _ => InternalStrategy::None,
711 };
712 let strategy_name = format!("{:?}", strategy);
713
714 let reasoner = SynapseReasoner::new(strategy);
715 let start_triples = store.store.len().unwrap_or(0);
716
717 let response = if req.materialize {
718 match reasoner.materialize(&store.store) {
719 Ok(count) => Ok(Response::new(ReasoningResponse {
720 success: true,
721 triples_inferred: count as u32,
722 message: format!(
723 "Materialized {} triples in namespace '{}'",
724 count, namespace
725 ),
726 })),
727 Err(e) => Err(Status::internal(e.to_string())),
728 }
729 } else {
730 match reasoner.apply(&store.store) {
731 Ok(triples) => Ok(Response::new(ReasoningResponse {
732 success: true,
733 triples_inferred: triples.len() as u32,
734 message: format!(
735 "Found {} inferred triples in namespace '{}'",
736 triples.len(),
737 namespace
738 ),
739 })),
740 Err(e) => Err(Status::internal(e.to_string())),
741 }
742 };
743
744 if let Ok(ref res) = response {
746 let inferred = res.get_ref().triples_inferred as usize;
747 self.audit.log(
748 namespace,
749 &strategy_name,
750 start_triples,
751 inferred,
752 0, vec![], );
755 }
756
757 response
758 }
759}
760
761pub async fn run_mcp_stdio(
762 engine: Arc<MySemanticEngine>,
763) -> Result<(), Box<dyn std::error::Error>> {
764 let server = crate::mcp_stdio::McpStdioServer::new(engine);
765 server.run().await
766}