use dashmap::DashMap;
use std::sync::Arc;
use tonic::{Request, Response, Status};

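// Types and service traits generated by `tonic-build` from `semantic_engine.proto`
// (the proto file and the corresponding `build.rs` step are assumed to live in this crate).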
pub mod proto {
    tonic::include_proto!("semantic_engine");
}

use proto::semantic_engine_server::SemanticEngine;
use proto::*;

use crate::ingest::IngestionEngine;
use crate::reasoner::{ReasoningStrategy as InternalStrategy, SynapseReasoner};
use crate::server::proto::{ReasoningStrategy, SearchMode};
use crate::store::SynapseStore;
use std::path::Path;

use crate::audit::InferenceAudit;
use crate::auth::NamespaceAuth;

#[derive(Clone)]
pub struct AuthToken(pub String);

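/// Tonic interceptor that pulls a bearer token out of the `authorization`
/// metadata header and stores it in the request extensions so handlers can
/// read it via `get_token`.
///
/// A minimal wiring sketch, assuming the generated `SemanticEngineServer`
/// wrapper and a standard tonic transport setup (the storage path is a
/// hypothetical example):
///
/// ```ignore
/// use proto::semantic_engine_server::SemanticEngineServer;
///
/// let engine = MySemanticEngine::new("/var/lib/synapse"); // hypothetical path
/// let svc = SemanticEngineServer::with_interceptor(engine, auth_interceptor);
/// tonic::transport::Server::builder()
///     .add_service(svc)
///     .serve("0.0.0.0:50051".parse()?)
///     .await?;
/// ```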
pub fn auth_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
    if let Some(token) = req.metadata().get("authorization")
        .and_then(|t| t.to_str().ok())
        .map(|s| s.trim_start_matches("Bearer ").to_string()) {
        req.extensions_mut().insert(AuthToken(token));
    }
    Ok(req)
}

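/// Resolves the caller's token for a handler: prefer the `AuthToken` placed in
/// the extensions by `auth_interceptor`, falling back to parsing the
/// `authorization` metadata header directly when the interceptor is not installed.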
fn get_token<T>(req: &Request<T>) -> Option<String> {
    if let Some(token) = req.extensions().get::<AuthToken>() {
        return Some(token.0.clone());
    }
    req.metadata().get("authorization")
        .and_then(|t| t.to_str().ok())
        .map(|s| s.trim_start_matches("Bearer ").to_string())
}

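/// gRPC service implementation. Holds one lazily opened `SynapseStore` per
/// namespace (cached in a `DashMap`), plus shared auth and inference-audit state.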
pub struct MySemanticEngine {
    pub storage_path: String,
    pub stores: DashMap<String, Arc<SynapseStore>>,
    pub auth: Arc<NamespaceAuth>,
    pub audit: Arc<InferenceAudit>,
}

impl MySemanticEngine {
    pub fn new(storage_path: &str) -> Self {
        // Namespace access rules are loaded from the environment at startup.
        let auth = Arc::new(NamespaceAuth::new());
        auth.load_from_env();

        Self {
            storage_path: storage_path.to_string(),
            stores: DashMap::new(),
            auth,
            audit: Arc::new(InferenceAudit::new()),
        }
    }

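    /// Returns the store for `namespace`, opening it under `storage_path` on
    /// first use and caching the handle for subsequent calls.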
    pub fn get_store(&self, namespace: &str) -> Result<Arc<SynapseStore>, Status> {
        if let Some(store) = self.stores.get(namespace) {
            return Ok(store.clone());
        }

        let store = SynapseStore::open(namespace, &self.storage_path).map_err(|e| {
            Status::internal(format!(
                "Failed to open store for namespace '{}': {}",
                namespace, e
            ))
        })?;

        let store_arc = Arc::new(store);
        self.stores.insert(namespace.to_string(), store_arc.clone());
        Ok(store_arc)
    }
}

#[tonic::async_trait]
impl SemanticEngine for MySemanticEngine {
    async fn ingest_triples(
        &self,
        request: Request<IngestRequest>,
    ) -> Result<Response<IngestResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let timestamp = chrono::Utc::now().to_rfc3339();
        let triple_count = req.triples.len();
        let mut sources: Vec<String> = Vec::new();

        // Collect the distinct provenance sources while flattening the request
        // into (subject, predicate, object) tuples for the store.
        let triples: Vec<(String, String, String)> = req
            .triples
            .into_iter()
            .map(|t| {
                if let Some(ref prov) = t.provenance {
                    if !prov.source.is_empty() && !sources.contains(&prov.source) {
                        sources.push(prov.source.clone());
                    }
                }
                (t.subject, t.predicate, t.object)
            })
            .collect();

        match store.ingest_triples(triples).await {
            Ok((added, _)) => {
                eprintln!(
                    "INGEST [{timestamp}] namespace={namespace} triples={triple_count} added={added} sources={:?}",
                    sources
                );
                Ok(Response::new(IngestResponse {
                    nodes_added: added,
                    edges_added: added,
                }))
            }
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn ingest_file(
        &self,
        request: Request<IngestFileRequest>,
    ) -> Result<Response<IngestResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() { "default" } else { &req.namespace };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "write") {
            return Err(Status::permission_denied(e));
        }
        let store = self.get_store(namespace)?;

        let engine = IngestionEngine::new(store);
        let path = Path::new(&req.file_path);

        match engine.ingest_file(path, namespace).await {
            Ok(count) => Ok(Response::new(IngestResponse {
                nodes_added: count,
                edges_added: count,
            })),
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

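    // Breadth-first neighborhood expansion over the RDF graph: each layer is
    // capped by `limit_per_layer`, already-visited URIs are skipped, and
    // results are scored as 1 / depth so closer neighbors rank higher.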
    async fn get_neighbors(
        &self,
        request: Request<NodeRequest>,
    ) -> Result<Response<NeighborResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };
        let store = self.get_store(namespace)?;

        // Unset request fields fall back to defaults: outgoing edges,
        // no predicate filter, depth 1, and no per-layer limit.
        let direction = if req.direction.is_empty() {
            "outgoing"
        } else {
            &req.direction
        };
        let edge_filter = if req.edge_filter.is_empty() {
            None
        } else {
            Some(req.edge_filter.as_str())
        };
        let max_depth = if req.depth == 0 {
            1
        } else {
            req.depth as usize
        };
        let limit_per_layer = if req.limit_per_layer == 0 {
            usize::MAX
        } else {
            req.limit_per_layer as usize
        };

        let mut neighbors = Vec::new();
        let mut visited = std::collections::HashSet::new();
        let mut current_frontier = Vec::new();

        // Seed the frontier with the starting node's URI, if it resolves.
        if let Some(start_uri) = store.get_uri(req.node_id) {
            current_frontier.push(start_uri.clone());
            visited.insert(start_uri);
        }

        for current_depth in 1..=max_depth {
            let mut next_frontier = Vec::new();
            let mut layer_count = 0;
            // Neighbors found at deeper layers score lower (1 / depth).
            let score = 1.0 / current_depth as f32;

            for uri in &current_frontier {
                if layer_count >= limit_per_layer {
                    break;
                }

                if direction == "outgoing" || direction == "both" {
                    if let Ok(subj) = oxigraph::model::NamedNodeRef::new(uri) {
                        for quad in
                            store
                                .store
                                .quads_for_pattern(Some(subj.into()), None, None, None)
                        {
                            if layer_count >= limit_per_layer {
                                break;
                            }
                            if let Ok(q) = quad {
                                let pred = q.predicate.to_string();
                                if let Some(filter) = edge_filter {
                                    if !pred.contains(filter) {
                                        continue;
                                    }
                                }
                                let obj_uri = q.object.to_string();
                                if !visited.contains(&obj_uri) {
                                    visited.insert(obj_uri.clone());
                                    let obj_id = store.get_or_create_id(&obj_uri);
                                    neighbors.push(Neighbor {
                                        node_id: obj_id,
                                        edge_type: pred,
                                        uri: obj_uri.clone(),
                                        direction: "outgoing".to_string(),
                                        depth: current_depth as u32,
                                        score,
                                    });
                                    next_frontier.push(obj_uri);
                                    layer_count += 1;
                                }
                            }
                        }
                    }
                }

                if direction == "incoming" || direction == "both" {
                    if let Ok(obj) = oxigraph::model::NamedNodeRef::new(uri) {
                        for quad in
                            store
                                .store
                                .quads_for_pattern(None, None, Some(obj.into()), None)
                        {
                            if layer_count >= limit_per_layer {
                                break;
                            }
                            if let Ok(q) = quad {
                                let pred = q.predicate.to_string();
                                if let Some(filter) = edge_filter {
                                    if !pred.contains(filter) {
                                        continue;
                                    }
                                }
                                let subj_uri = q.subject.to_string();
                                if !visited.contains(&subj_uri) {
                                    visited.insert(subj_uri.clone());
                                    let subj_id = store.get_or_create_id(&subj_uri);
                                    neighbors.push(Neighbor {
                                        node_id: subj_id,
                                        edge_type: pred,
                                        uri: subj_uri.clone(),
                                        direction: "incoming".to_string(),
                                        depth: current_depth as u32,
                                        score,
                                    });
                                    next_frontier.push(subj_uri);
                                    layer_count += 1;
                                }
                            }
                        }
                    }
                }
            }

            current_frontier = next_frontier;
            if current_frontier.is_empty() {
                break;
            }
        }

        // Closest (highest-scoring) neighbors first.
        neighbors.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(Response::new(NeighborResponse { neighbors }))
    }

    async fn search(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };
        let store = self.get_store(namespace)?;

        // Plain search delegates to the hybrid pipeline with a graph depth of 0.
        match store.hybrid_search(&req.query, req.limit as usize, 0).await {
            Ok(results) => {
                let grpc_results = results
                    .into_iter()
                    .enumerate()
                    .map(|(idx, (uri, score))| SearchResult {
                        node_id: idx as u32,
                        score,
                        content: uri.clone(),
                        uri,
                    })
                    .collect();
                Ok(Response::new(SearchResponse {
                    results: grpc_results,
                }))
            }
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn resolve_id(
        &self,
        request: Request<ResolveRequest>,
    ) -> Result<Response<ResolveResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };
        let store = self.get_store(namespace)?;

        let uri_to_id = store.uri_to_id.read().unwrap();
        if let Some(&node_id) = uri_to_id.get(&req.content) {
            Ok(Response::new(ResolveResponse {
                node_id,
                found: true,
            }))
        } else {
            Ok(Response::new(ResolveResponse {
                node_id: 0,
                found: false,
            }))
        }
    }
372 async fn get_all_triples(
373 &self,
374 request: Request<EmptyRequest>,
375 ) -> Result<Response<TriplesResponse>, Status> {
376 let req = request.into_inner();
377 let namespace = if req.namespace.is_empty() {
378 "default"
379 } else {
380 &req.namespace
381 };
382 let store = self.get_store(namespace)?;
383
384 let mut triples = Vec::new();
385
386 for quad in store.store.iter().map(|q| q.unwrap()) {
387 triples.push(Triple {
388 subject: quad.subject.to_string(),
389 predicate: quad.predicate.to_string(),
390 object: quad.object.to_string(),
391 provenance: Some(Provenance {
392 source: "oxigraph".to_string(),
393 timestamp: "".to_string(),
394 method: "storage".to_string(),
395 }),
396 embedding: vec![],
397 });
398 }
399
400 Ok(Response::new(TriplesResponse { triples }))
401 }

    async fn query_sparql(
        &self,
        request: Request<SparqlRequest>,
    ) -> Result<Response<SparqlResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };
        let store = self.get_store(namespace)?;

        match store.query_sparql(&req.query) {
            Ok(json) => Ok(Response::new(SparqlResponse { results_json: json })),
            Err(e) => Err(Status::internal(e.to_string())),
        }
    }

    async fn delete_namespace_data(
        &self,
        request: Request<EmptyRequest>,
    ) -> Result<Response<DeleteResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        // Drop the cached store handle, then remove the namespace's on-disk data.
        self.stores.remove(namespace);

        let path = Path::new(&self.storage_path).join(namespace);
        if path.exists() {
            std::fs::remove_dir_all(path).map_err(|e| Status::internal(e.to_string()))?;
        }

        Ok(Response::new(DeleteResponse {
            success: true,
            message: format!("Deleted namespace '{}'", namespace),
        }))
    }

    async fn hybrid_search(
        &self,
        request: Request<HybridSearchRequest>,
    ) -> Result<Response<SearchResponse>, Status> {
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };
        let store = self.get_store(namespace)?;

        let vector_k = req.vector_k as usize;
        let graph_depth = req.graph_depth;

        // Vector-only and hybrid modes share the same store call; any other
        // mode yields an empty result set.
        let results = match SearchMode::try_from(req.mode) {
            Ok(SearchMode::VectorOnly) | Ok(SearchMode::Hybrid) => store
                .hybrid_search(&req.query, vector_k, graph_depth)
                .await
                .map_err(|e| Status::internal(format!("Hybrid search failed: {}", e)))?,
            _ => vec![],
        };

        let grpc_results = results
            .into_iter()
            .enumerate()
            .map(|(idx, (uri, score))| SearchResult {
                node_id: idx as u32,
                score,
                content: uri.clone(),
                uri,
            })
            .collect();

        Ok(Response::new(SearchResponse {
            results: grpc_results,
        }))
    }

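    // Requires the "reason" permission. Either materializes inferred triples
    // into the store or only reports how many would be inferred, and records
    // the outcome in the inference audit log.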
    async fn apply_reasoning(
        &self,
        request: Request<ReasoningRequest>,
    ) -> Result<Response<ReasoningResponse>, Status> {
        let token = get_token(&request);
        let req = request.into_inner();
        let namespace = if req.namespace.is_empty() {
            "default"
        } else {
            &req.namespace
        };

        if let Err(e) = self.auth.check(token.as_deref(), namespace, "reason") {
            return Err(Status::permission_denied(e));
        }

        let store = self.get_store(namespace)?;

        let strategy = match ReasoningStrategy::try_from(req.strategy) {
            Ok(ReasoningStrategy::Rdfs) => InternalStrategy::RDFS,
            Ok(ReasoningStrategy::Owlrl) => InternalStrategy::OWLRL,
            _ => InternalStrategy::None,
        };
        let strategy_name = format!("{:?}", strategy);

        let reasoner = SynapseReasoner::new(strategy);
        let start_triples = store.store.len().unwrap_or(0);

        let response = if req.materialize {
            // Materialization writes the inferred triples back into the store.
            match reasoner.materialize(&store.store) {
                Ok(count) => Ok(Response::new(ReasoningResponse {
                    success: true,
                    triples_inferred: count as u32,
                    message: format!(
                        "Materialized {} triples in namespace '{}'",
                        count, namespace
                    ),
                })),
                Err(e) => Err(Status::internal(e.to_string())),
            }
        } else {
            // Dry run: report how many triples would be inferred without persisting them.
            match reasoner.apply(&store.store) {
                Ok(triples) => Ok(Response::new(ReasoningResponse {
                    success: true,
                    triples_inferred: triples.len() as u32,
                    message: format!(
                        "Found {} inferred triples in namespace '{}'",
                        triples.len(),
                        namespace
                    ),
                })),
                Err(e) => Err(Status::internal(e.to_string())),
            }
        };

        // Record successful runs in the inference audit log.
        if let Ok(ref res) = response {
            let inferred = res.get_ref().triples_inferred as usize;
            self.audit.log(
                namespace,
                &strategy_name,
                start_triples,
                inferred,
                0,
                vec![],
            );
        }

        response
    }
}

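/// Bridges the same engine to the stdio-based MCP server implemented in
/// `crate::mcp_stdio`, as an alternative transport to gRPC.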
pub async fn run_mcp_stdio(
    engine: Arc<MySemanticEngine>,
) -> Result<(), Box<dyn std::error::Error>> {
    let server = crate::mcp_stdio::McpStdioServer::new(engine);
    server.run().await
}