use dashmap::DashMap;
use std::sync::Arc;
use tonic::{Request, Response, Status};

pub mod proto {
    tonic::include_proto!("semantic_engine");
}

use proto::semantic_engine_server::SemanticEngine;
use proto::*;

use crate::ingest::IngestionEngine;
use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
use crate::server::proto::{ReasoningStrategy, SearchMode};
use crate::store::{IngestTriple, SynapseStore};
use std::path::Path;

use crate::audit::InferenceAudit;
use crate::auth::NamespaceAuth;

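/// Bearer token extracted from request metadata by `auth_interceptor`.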
#[derive(Clone)]
pub struct AuthToken(pub String);

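/// Tonic interceptor: if the request carries an `authorization` header, strip the
/// `Bearer ` prefix and stash the token in the request extensions so handlers can
/// read it without re-parsing metadata. Requests without a token pass through.
///
/// Wiring it up is a one-liner on the generated server type, e.g. (sketch):
/// `SemanticEngineServer::with_interceptor(engine, auth_interceptor)`.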
#[allow(clippy::result_large_err)]
pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
    if let Some(token) = req
        .metadata()
        .get("authorization")
        .and_then(|t| t.to_str().ok())
        .map(|s| s.trim_start_matches("Bearer ").to_string())
    {
        req.extensions_mut().insert(AuthToken(token));
    }
    Ok(req)
}

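/// Resolve the caller's token: prefer the value stashed by `auth_interceptor`,
/// fall back to parsing the `authorization` header directly.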
fn get_token<T>(req: &Request<T>) -> Option<String> {
    if let Some(token) = req.extensions().get::<AuthToken>() {
        return Some(token.0.clone());
    }
    req.metadata()
        .get("authorization")
        .and_then(|t| t.to_str().ok())
        .map(|s| s.trim_start_matches("Bearer ").to_string())
}

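/// gRPC implementation of the semantic engine. Each namespace maps to its own
/// `SynapseStore`, opened lazily and cached in `stores`.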
#[derive(Clone)]
pub struct MySemanticEngine {
    pub storage_path: String,
    pub stores: Arc<DashMap<String, Arc<SynapseStore>>>,
    pub auth: Arc<NamespaceAuth>,
    pub audit: Arc<InferenceAudit>,
}

impl MySemanticEngine {
    pub fn new(storage_path: &str) -> Self {
        let auth = Arc::new(NamespaceAuth::new());
        auth.load_from_env();

        Self {
            storage_path: storage_path.to_string(),
            stores: Arc::new(DashMap::new()),
            auth,
            audit: Arc::new(InferenceAudit::new()),
        }
    }

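    /// Flush every open store; called on graceful shutdown.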
    pub async fn shutdown(&self) {
        eprintln!("Shutting down... flushing {} stores", self.stores.len());
        for entry in self.stores.iter() {
            let store = entry.value();
            if let Err(e) = store.flush() {
                eprintln!("Failed to flush store '{}': {}", entry.key(), e);
            }
        }
        eprintln!("Shutdown complete.");
    }

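    /// Return the cached store for `namespace`, opening and caching it on first use.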
    #[allow(clippy::result_large_err)]
    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
        if let Some(store) = self.stores.get(namespace) {
            return Ok(store.clone());
        }

        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
            Status::internal(format!(
                "Failed to open store for namespace '{}': {}",
                namespace, e
            ))
        })?;

        let store_arc = Arc::new(store);
        self.stores.insert(namespace.to_string(), store_arc.clone());
        Ok(store_arc)
    }
}

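// Every RPC below follows the same pattern: resolve the namespace (defaulting to
// "default"), check the caller's token against the required action, then operate
// on that namespace's store.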
#[tonic::async_trait]
impl SemanticEngine for MySemanticEngine {
    async fn ingest_triples(
        &self,
        request: Request<IngestRequest>,
    ) -> Result<Response<IngestResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let timestamp = chrono::Utc::now().to_rfc3339();
        let triple_count = req.triples.len();
        let mut sources: Vec<String> = Vec::new();

        let triples: Vec<IngestTriple> = req
            .triples
            .into_iter()
            .map(|t| {
                if let Some(ref prov) = t.provenance {
                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
                        sources.push(prov.source.clone());
                    }
                }
                IngestTriple {
                    subject: t.subject,
                    predicate: t.predicate,
                    object: t.object,
                    provenance: t.provenance.map(|p| crate::store::Provenance {
                        source: p.source,
                        timestamp: p.timestamp,
                        method: p.method,
                    }),
                }
            })
            .collect();

        match store.ingest_triples(triples).await {
            Ok((added, _)) => {
                eprintln!(
                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
                    sources
                );
                Ok(Response::new(IngestResponse {
                    nodes_added: added,
                    edges_added: added,
                }))
            }
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn ingest_file(
        &self,
        request: Request<IngestFileRequest>,
    ) -> Result<Response<IngestResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
            return Err(Status::permission_denied(e));
        }
        let store = self.get_store(namespace)?;

        let engine = IngestionEngine::new(store);
        let path = Path::new(&req.file_path);

        match engine.ingest_file(path, namespace).await {
            Ok(count) => Ok(Response::new(IngestResponse {
                nodes_added: count,
                edges_added: count,
            })),
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn get_neighbors(
        &self,
        request: Request<NodeRequest>,
    ) -> Result<Response<NeighborResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let direction = if req.direction.is_empty() {
            "outgoing"
        } else {
            &req.direction
        };
        let edge_filter = if req.edge_filter.is_empty() {
            None
        } else {
            Some(req.edge_filter.as_str())
        };
        let node_type_filter = if req.node_type_filter.is_empty() {
            None
        } else {
            Some(req.node_type_filter.as_str())
        };
        let max_depth = if req.depth == 0 {
            1
        } else {
            req.depth as usize
        };
        let limit_per_layer = if req.limit_per_layer == 0 {
            usize::MAX
        } else {
            req.limit_per_layer as usize
        };

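        // Breadth-first expansion from the requested node. `visited` prevents cycles;
        // `current_frontier` holds the URIs discovered in the previous layer.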
        let mut neighbors = Vec::new();
        let mut visited = std::collections::HashSet::new();
        let mut current_frontier = Vec::new();

        if let Some(start_uri) = store.get_uri(req.node_id) {
            current_frontier.push(start_uri.clone());
            visited.insert(start_uri);
        }

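        // One iteration per depth layer; the base score decays with depth.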
        for current_depth in 1..=max_depth {
            let mut next_frontier = Vec::new();
            let mut layer_count = 0;
            let base_score = 1.0 / current_depth as f32;

            for uri in &current_frontier {
                if layer_count >= limit_per_layer {
                    break;
                }

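                // Outgoing edges: the current node is the subject of the quad.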
                if direction == "outgoing" || direction == "both" {
                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
                        for quad in
                            store
                                .store
                                .quads_for_pattern(Some(subj.into()), None, None, None)
                        {
                            if layer_count >= limit_per_layer {
                                break;
                            }
                            if let Ok(q) = quad {
                                let pred = q.predicate.to_string();
                                if let Some(filter) = edge_filter {
                                    if !pred.contains(filter) {
                                        continue;
                                    }
                                }
                                let obj_term = q.object;
                                let obj_uri = obj_term.to_string();

                                if let Some(type_filter) = node_type_filter {
                                    let passed =
                                        if let oxigraph::model::Term::NamedNode(ref n) = obj_term {
                                            let rdf_type = oxigraph::model::NamedNodeRef::new(
                                                "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
                                            )
                                            .unwrap();
                                            if let Ok(target_type) =
                                                oxigraph::model::NamedNodeRef::new(type_filter)
                                            {
                                                store
                                                    .store
                                                    .quads_for_pattern(
                                                        Some(n.into()),
                                                        Some(rdf_type),
                                                        Some(target_type.into()),
                                                        None,
                                                    )
                                                    .next()
                                                    .is_some()
                                            } else {
                                                false
                                            }
                                        } else {
                                            false
                                        };
                                    if !passed {
                                        continue;
                                    }
                                }

                                let clean_uri = match &obj_term {
                                    oxigraph::model::Term::NamedNode(n) => n.as_str(),
                                    _ => &obj_uri,
                                };

                                if !visited.contains(&obj_uri) {
                                    visited.insert(obj_uri.clone());
                                    let obj_id = store.get_or_create_id(&obj_uri);

                                    let mut neighbor_score = base_score;
                                    if req.scoring_strategy == "degree" {
                                        let degree = store.get_degree(clean_uri);
                                        if degree > 1 {
                                            neighbor_score /= (degree as f32).ln().max(1.0);
                                        }
                                    }

                                    neighbors.push(Neighbor {
                                        node_id: obj_id,
                                        edge_type: pred,
                                        uri: obj_uri.clone(),
                                        direction: "outgoing".to_string(),
                                        depth: current_depth as u32,
                                        score: neighbor_score,
                                    });
                                    next_frontier.push(obj_uri);
                                    layer_count += 1;
                                }
                            }
                        }
                    }
                }

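                // Incoming edges: the current node is the object of the quad.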
                if direction == "incoming" || direction == "both" {
                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
                        for quad in
                            store
                                .store
                                .quads_for_pattern(None, None, Some(obj.into()), None)
                        {
                            if layer_count >= limit_per_layer {
                                break;
                            }
                            if let Ok(q) = quad {
                                let pred = q.predicate.to_string();
                                if let Some(filter) = edge_filter {
                                    if !pred.contains(filter) {
                                        continue;
                                    }
                                }
                                let subj_term = q.subject;
                                let subj_uri = subj_term.to_string();

                                if let Some(type_filter) = node_type_filter {
                                    let passed = if let oxigraph::model::Subject::NamedNode(ref n) =
                                        subj_term
                                    {
                                        let rdf_type = oxigraph::model::NamedNodeRef::new(
                                            "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
                                        )
                                        .unwrap();
                                        if let Ok(target_type) =
                                            oxigraph::model::NamedNodeRef::new(type_filter)
                                        {
                                            store
                                                .store
                                                .quads_for_pattern(
                                                    Some(n.into()),
                                                    Some(rdf_type),
                                                    Some(target_type.into()),
                                                    None,
                                                )
                                                .next()
                                                .is_some()
                                        } else {
                                            false
                                        }
                                    } else {
                                        false
                                    };
                                    if !passed {
                                        continue;
                                    }
                                }

                                let clean_uri = match &subj_term {
                                    oxigraph::model::Subject::NamedNode(n) => n.as_str(),
                                    _ => &subj_uri,
                                };

                                if !visited.contains(&subj_uri) {
                                    visited.insert(subj_uri.clone());
                                    let subj_id = store.get_or_create_id(&subj_uri);

                                    let mut neighbor_score = base_score;
                                    if req.scoring_strategy == "degree" {
                                        let degree = store.get_degree(clean_uri);
                                        if degree > 1 {
                                            neighbor_score /= (degree as f32).ln().max(1.0);
                                        }
                                    }

                                    neighbors.push(Neighbor {
                                        node_id: subj_id,
                                        edge_type: pred,
                                        uri: subj_uri.clone(),
                                        direction: "incoming".to_string(),
                                        depth: current_depth as u32,
                                        score: neighbor_score,
                                    });
                                    next_frontier.push(subj_uri);
                                    layer_count += 1;
                                }
                            }
                        }
                    }
                }
            }

            current_frontier = next_frontier;
            if current_frontier.is_empty() {
                break;
            }
        }

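        // Highest-scoring neighbors first.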
        neighbors.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(Response::new(NeighborResponse { neighbors }))
    }

    async fn search(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
            Ok(results) => {
                let grpc_results = results
                    .into_iter()
                    .enumerate()
                    .map(|(idx, (uri, score))| SearchResult {
                        node_id: idx as u32,
                        score,
                        content: uri.clone(),
                        uri,
                    })
                    .collect();
                Ok(Response::new(SearchResponse {
                    results: grpc_results,
                }))
            }
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn resolve_id(
        &self,
        request: Request<ResolveRequest>,
    ) -> Result<Response<ResolveResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let uri = store.ensure_uri(&req.content);

        let uri_to_id = store.uri_to_id.read().unwrap();
        if let Some(&node_id) = uri_to_id.get(&uri) {
            Ok(Response::new(ResolveResponse {
                node_id,
                found: true,
            }))
        } else {
            Ok(Response::new(ResolveResponse {
                node_id: 0,
                found: false,
            }))
        }
    }

    async fn get_all_triples(
        &self,
        request: Request<EmptyRequest>,
    ) -> Result<Response<TriplesResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let mut triples = Vec::new();

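        // Oxigraph's term Display wraps IRIs in angle brackets; strip them so clients
        // receive bare URIs.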
        for quad in store.store.iter().map(|q| q.unwrap()) {
            let s = quad.subject.to_string();
            let p = quad.predicate.to_string();
            let o = quad.object.to_string();

            let clean_s = if s.starts_with('<') && s.ends_with('>') {
                s[1..s.len() - 1].to_string()
            } else {
                s
            };
            let clean_p = if p.starts_with('<') && p.ends_with('>') {
                p[1..p.len() - 1].to_string()
            } else {
                p
            };
            let clean_o = if o.starts_with('<') && o.ends_with('>') {
                o[1..o.len() - 1].to_string()
            } else {
                o
            };

            triples.push(Triple {
                subject: clean_s,
                predicate: clean_p,
                object: clean_o,
                provenance: Some(Provenance {
                    source: "oxigraph".to_string(),
                    timestamp: "".to_string(),
                    method: "storage".to_string(),
                }),
                embedding: vec![],
            });
        }

        Ok(Response::new(TriplesResponse { triples }))
    }

    async fn query_sparql(
        &self,
        request: Request<SparqlRequest>,
    ) -> Result<Response<SparqlResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        match store.query_sparql(&req.query) {
            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn delete_namespace_data(
        &self,
        request: Request<EmptyRequest>,
    ) -> Result<Response<DeleteResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "delete") {
            return Err(Status::permission_denied(e));
        }

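        // Drop the cached handle, then remove the namespace's on-disk data.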
        self.stores.remove(namespace);

        let path = Path::new(&self.storage_path).join(namespace);
        if path.exists() {
            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
        }

        Ok(Response::new(DeleteResponse {
            success: true,
            message: format!("Deleted namespace '{}'", namespace),
        }))
    }

    async fn hybrid_search(
        &self,
        request: Request<HybridSearchRequest>,
    ) -> Result<Response<SearchResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "read") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let vector_k = req.vector_k as usize;
        let graph_depth = req.graph_depth;

        let results = match SearchMode::try_from(req.mode) {
            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
                .hybrid_search(&req.query, vector_k, graph_depth)
                .await
                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
            _ => vec![],
        };

        let grpc_results = results
            .into_iter()
            .enumerate()
            .map(|(idx, (uri, score))| SearchResult {
                node_id: idx as u32,
                score,
                content: uri.clone(),
                uri,
            })
            .collect();

        Ok(Response::new(SearchResponse {
            results: grpc_results,
        }))
    }

    async fn apply_reasoning(
        &self,
        request: Request<ReasoningRequest>,
    ) -> Result<Response<ReasoningResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let strategy = match ReasoningStrategy::try_from(req.strategy) {
            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
            _ => InternalStrategy::None,
        };
        let strategy_name = format!("{:?}", strategy);

        let reasoner = SynapseReasoner::new(strategy);
        let start_triples = store.store.len().unwrap_or(0);

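        // `materialize` writes inferred triples back into the store; otherwise the
        // inferences are only computed and counted.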
        let response = if req.materialize {
            match reasoner.materialize(&store.store) {
                Ok(count) => Ok(Response::new(ReasoningResponse {
                    success: true,
                    triples_inferred: count as u32,
                    message: format!(
                        "Materialized {} triples in namespace '{}'",
                        count, namespace
                    ),
                })),
                Err(e) => Err(Status::internal(e.to_string())),
            }
        } else {
            match reasoner.apply(&store.store) {
                Ok(triples) => Ok(Response::new(ReasoningResponse {
                    success: true,
                    triples_inferred: triples.len() as u32,
                    message: format!(
                        "Found {} inferred triples in namespace '{}'",
                        triples.len(),
                        namespace
                    ),
                })),
                Err(e) => Err(Status::internal(e.to_string())),
            }
        };

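        // Record the inference run in the audit log; the trailing zero and empty vec
        // are left unpopulated by this handler.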
        if let Ok(ref res) = response {
            let inferred = res.get_ref().triples_inferred as usize;
            self.audit.log(
                namespace,
                &strategy_name,
                start_triples,
                inferred,
                0,
                vec![],
            );
        }

        response
    }
}

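/// Run the MCP stdio transport on top of the same engine instance used by the gRPC service.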
pub async fn run_mcp_stdio(
    engine: Arc<MySemanticEngine>,
) -> Result<(), Box<dyn std::error::Error>> {
    let server = crate::mcp_stdio::McpStdioServer::new(engine);
    server.run().await
}