Skip to main content

synapse_core/
store.rs

1use crate::persistence::{load_bincode, save_bincode};
2use crate::vector_store::VectorStore;
3use anyhow::Result;
4use oxigraph::model::*;
5use oxigraph::store::Store;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, RwLock};
11use uuid::Uuid;
12
/// Number of newly assigned URI IDs after which `get_or_create_id` triggers an
/// automatic (best-effort) save of the URI mappings.
const DEFAULT_MAPPING_SAVE_THRESHOLD: usize = 1000;
14
15/// Persisted URI mappings
16#[derive(Serialize, Deserialize, Default)]
17struct UriMappings {
18    uri_to_id: HashMap<String, u32>,
19    next_id: u32,
20}
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
23pub struct Provenance {
24    pub source: String,
25    pub timestamp: String,
26    pub method: String,
27}
28
29pub struct IngestTriple {
30    pub subject: String,
31    pub predicate: String,
32    pub object: String,
33    pub provenance: Option<Provenance>,
34}
35
36pub struct SynapseStore {
37    pub store: Store,
38    pub namespace: String,
39    pub storage_path: PathBuf,
40    // Mapping for gRPC compatibility (ID <-> URI)
41    pub id_to_uri: RwLock<HashMap<u32, String>>,
42    pub uri_to_id: RwLock<HashMap<String, u32>>,
43    pub next_id: std::sync::atomic::AtomicU32,
44    // Vector store for hybrid search
45    pub vector_store: Option<Arc<VectorStore>>,
46    // Persistence state
47    dirty_count: AtomicUsize,
48    save_threshold: usize,
49}
50
51impl SynapseStore {
52    pub fn open(namespace: &str, storage_path: &str) -> Result<Self> {
53        let path = PathBuf::from(storage_path).join(namespace);
54        std::fs::create_dir_all(&path)?;
55        let store = Store::open(&path)?;
56
57        // Load persisted URI mappings if they exist
58        let mappings_path_bin = path.join("uri_mappings.bin");
59        let mappings_path_json = path.join("uri_mappings.json");
60
61        let (uri_to_id, id_to_uri, next_id) = if mappings_path_bin.exists() {
62            let mappings: UriMappings = load_bincode(&mappings_path_bin)?;
63            let id_to_uri: HashMap<u32, String> = mappings
64                .uri_to_id
65                .iter()
66                .map(|(uri, &id)| (id, uri.clone()))
67                .collect();
68            (mappings.uri_to_id, id_to_uri, mappings.next_id)
69        } else if mappings_path_json.exists() {
70            let content = std::fs::read_to_string(&mappings_path_json)?;
71            let mappings: UriMappings = serde_json::from_str(&content)?;
72            let id_to_uri: HashMap<u32, String> = mappings
73                .uri_to_id
74                .iter()
75                .map(|(uri, &id)| (id, uri.clone()))
76                .collect();
77            (mappings.uri_to_id, id_to_uri, mappings.next_id)
78        } else {
79            (HashMap::new(), HashMap::new(), 1)
80        };
81
82        // Initialize vector store (optional, can fail gracefully)
83        let vector_store = VectorStore::new(namespace).ok().map(Arc::new);
84
85        Ok(Self {
86            store,
87            namespace: namespace.to_string(),
88            storage_path: path,
89            id_to_uri: RwLock::new(id_to_uri),
90            uri_to_id: RwLock::new(uri_to_id),
91            next_id: std::sync::atomic::AtomicU32::new(next_id),
92            vector_store,
93            dirty_count: AtomicUsize::new(0),
94            save_threshold: DEFAULT_MAPPING_SAVE_THRESHOLD,
95        })
96    }
97
98    /// Save URI mappings to disk
99    fn save_mappings(&self) -> Result<()> {
100        let mappings = UriMappings {
101            uri_to_id: self.uri_to_id.read().unwrap().clone(),
102            next_id: self.next_id.load(std::sync::atomic::Ordering::Relaxed),
103        };
104        // Capture the count before saving? No, we just care that we saved the current state.
105        // But if new items are added during save, the dirty count will increment.
106        // We need to subtract what we think we saved.
107        // Since we save the *entire* map, we effectively save *all* dirty items up to that point.
108        // So we can read the dirty count, save, then subtract.
109        let current_dirty = self.dirty_count.load(Ordering::Relaxed);
110
111        save_bincode(&self.storage_path.join("uri_mappings.bin"), &mappings)?;
112
113        if current_dirty > 0 {
114            let _ = self.dirty_count.fetch_sub(current_dirty, Ordering::Relaxed);
115        }
116        Ok(())
117    }
118
119    /// Force save all data to disk
120    pub fn flush(&self) -> Result<()> {
121        self.save_mappings()?;
122        if let Some(ref vs) = self.vector_store {
123            vs.flush()?;
124        }
125        Ok(())
126    }
127
128    pub fn get_or_create_id(&self, uri: &str) -> u32 {
129        {
130            let map = self.uri_to_id.read().unwrap();
131            if let Some(&id) = map.get(uri) {
132                return id;
133            }
134        }
135
136        let mut uri_map = self.uri_to_id.write().unwrap();
137        let mut id_map = self.id_to_uri.write().unwrap();
138
139        if let Some(&id) = uri_map.get(uri) {
140            return id;
141        }
142
143        let id = self
144            .next_id
145            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
146        uri_map.insert(uri.to_string(), id);
147        id_map.insert(id, uri.to_string());
148
149        drop(uri_map);
150        drop(id_map);
151
152        // Check if we need to auto-save mappings
153        let count = self.dirty_count.fetch_add(1, Ordering::Relaxed);
154        if count + 1 >= self.save_threshold {
155            let _ = self.save_mappings();
156        }
157
158        id
159    }
160
161    pub fn get_uri(&self, id: u32) -> Option<String> {
162        self.id_to_uri.read().unwrap().get(&id).cloned()
163    }
164
165    pub async fn ingest_triples(&self, triples: Vec<IngestTriple>) -> Result<(u32, u32)> {
166        let mut added = 0;
167
168        // Group by provenance to optimize batch insertion into named graphs
169        let mut batches: HashMap<Option<Provenance>, Vec<(String, String, String)>> =
170            HashMap::new();
171
172        for t in triples {
173            batches
174                .entry(t.provenance)
175                .or_default()
176                .push((t.subject, t.predicate, t.object));
177        }
178
179        for (prov, batch_triples) in batches {
180            let graph_name = if let Some(p) = &prov {
181                let uuid = Uuid::new_v4();
182                let uri = format!("urn:batch:{}", uuid);
183
184                let batch_node = NamedNode::new_unchecked(&uri);
185                let p_derived =
186                    NamedNode::new_unchecked("http://www.w3.org/ns/prov#wasDerivedFrom");
187                let p_time = NamedNode::new_unchecked("http://www.w3.org/ns/prov#generatedAtTime");
188                let p_method = NamedNode::new_unchecked("http://www.w3.org/ns/prov#wasGeneratedBy");
189
190                let o_source = Literal::new_simple_literal(&p.source);
191                let o_time = Literal::new_simple_literal(&p.timestamp);
192                let o_method = Literal::new_simple_literal(&p.method);
193
194                self.store.insert(&Quad::new(
195                    batch_node.clone(),
196                    p_derived,
197                    o_source,
198                    GraphName::DefaultGraph,
199                ))?;
200                self.store.insert(&Quad::new(
201                    batch_node.clone(),
202                    p_time,
203                    o_time,
204                    GraphName::DefaultGraph,
205                ))?;
206                self.store.insert(&Quad::new(
207                    batch_node.clone(),
208                    p_method,
209                    o_method,
210                    GraphName::DefaultGraph,
211                ))?;
212
213                GraphName::NamedNode(batch_node)
214            } else {
215                GraphName::DefaultGraph
216            };
217
218            for (s, p, o) in batch_triples {
219                let subject_uri = self.ensure_uri(&s);
220                let predicate_uri = self.ensure_uri(&p);
221                let object_uri = self.ensure_uri(&o);
222
223                // Register URIs in the ID mapping (for gRPC compatibility)
224                self.get_or_create_id(&subject_uri);
225                self.get_or_create_id(&predicate_uri);
226                self.get_or_create_id(&object_uri);
227
228                let subject = Subject::NamedNode(NamedNode::new_unchecked(&subject_uri));
229                let predicate = NamedNode::new_unchecked(&predicate_uri);
230                let object = Term::NamedNode(NamedNode::new_unchecked(&object_uri));
231
232                let quad = Quad::new(subject, predicate, object, graph_name.clone());
233                if self.store.insert(&quad)? {
234                    // Also index in vector store if available
235                    if let Some(ref vs) = self.vector_store {
236                        // Create searchable content from triple
237                        let content = format!("{} {} {}", s, p, o);
238                        // Use a deterministic hash/key for the triple to allow multiple triples per subject
239                        // We use the content itself as key or a hash of it.
240                        // Ideally we should use a hash, but for simplicity let's use the formatted content string as key prefix?
241                        // Actually, just using a unique ID is fine, but we want idempotency.
242                        // format!("{}|{}|{}", s, p, o)
243                        let key = format!("{}|{}|{}", subject_uri, predicate_uri, object_uri);
244
245                        // Pass metadata including the subject URI for graph expansion later
246                        let metadata = serde_json::json!({
247                            "uri": subject_uri,
248                            "predicate": predicate_uri,
249                            "object": object_uri,
250                            "type": "triple"
251                        });
252
253                        if let Err(e) = vs.add(&key, &content, metadata).await {
254                            // Rollback graph insertion to ensure consistency
255                            self.store.remove(&quad)?;
256                            return Err(anyhow::anyhow!(
257                                "Vector store insertion failed, rolled back graph change: {}",
258                                e
259                            ));
260                        }
261                    }
262                    added += 1;
263                }
264            }
265        }
266
267        Ok((added, 0))
268    }
269
270    /// Hybrid search: vector similarity + graph expansion
271    pub async fn hybrid_search(
272        &self,
273        query: &str,
274        vector_k: usize,
275        graph_depth: u32,
276    ) -> Result<Vec<(String, f32)>> {
277        let mut results = Vec::new();
278
279        // Step 1: Vector search
280        if let Some(ref vs) = self.vector_store {
281            let vector_results = vs.search(query, vector_k).await?;
282
283            for result in vector_results {
284                // Use the URI from metadata/result (which maps to Subject URI for triples)
285                let uri = result.uri.clone();
286                results.push((uri.clone(), result.score));
287
288                // Step 2: Graph expansion (if depth > 0)
289                if graph_depth > 0 {
290                    let expanded = self.expand_graph(&uri, graph_depth)?;
291                    for expanded_uri in expanded {
292                        // Add with slightly lower score
293                        results.push((expanded_uri, result.score * 0.8));
294                    }
295                }
296            }
297        }
298
299        // Remove duplicates and sort by score
300        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
301        results.dedup_by(|a, b| a.0 == b.0);
302
303        Ok(results)
304    }
305
306    /// Expand graph from a starting URI
307    fn expand_graph(&self, start_uri: &str, depth: u32) -> Result<Vec<String>> {
308        let mut expanded = Vec::new();
309
310        if depth == 0 {
311            return Ok(expanded);
312        }
313
314        // Query for all triples where start_uri is subject or object
315        let subject = NamedNodeRef::new(start_uri).ok();
316
317        if let Some(subj) = subject {
318            for q in self
319                .store
320                .quads_for_pattern(Some(subj.into()), None, None, None)
321                .flatten()
322            {
323                expanded.push(q.object.to_string());
324
325                // Recursive expansion (simplified, depth-1)
326                if depth > 1 {
327                    let nested = self.expand_graph(&q.object.to_string(), depth - 1)?;
328                    expanded.extend(nested);
329                }
330            }
331        }
332
333        Ok(expanded)
334    }
335
336    pub fn query_sparql(&self, query: &str) -> Result<String> {
337        use oxigraph::sparql::QueryResults;
338
339        let results = self.store.query(query)?;
340
341        match results {
342            QueryResults::Solutions(solutions) => {
343                let mut results_array = Vec::new();
344                for solution in solutions {
345                    let sol = solution?;
346                    let mut mapping = serde_json::Map::new();
347                    for (variable, value) in sol.iter() {
348                        mapping.insert(
349                            variable.to_string(),
350                            serde_json::to_value(value.to_string()).unwrap(),
351                        );
352                    }
353                    results_array.push(serde_json::Value::Object(mapping));
354                }
355                Ok(serde_json::to_string(&results_array)?)
356            }
357            _ => Ok("[]".to_string()),
358        }
359    }
360
361    pub fn get_degree(&self, uri: &str) -> usize {
362        let node = NamedNodeRef::new(uri).ok();
363        if let Some(n) = node {
364            let outgoing = self
365                .store
366                .quads_for_pattern(Some(n.into()), None, None, None)
367                .count();
368            let incoming = self
369                .store
370                .quads_for_pattern(None, None, Some(n.into()), None)
371                .count();
372            outgoing + incoming
373        } else {
374            0
375        }
376    }
377
378    pub fn ensure_uri(&self, s: &str) -> String {
379        if s.starts_with("http") || s.starts_with("urn:") {
380            s.to_string()
381        } else {
382            format!("http://synapse.os/{}", s)
383        }
384    }
385}