// sqlitegraph/graph_opt.rs

1use ahash::AHashSet;
2use serde_json::Value;
3
4use crate::{
5    SqliteGraphError,
6    cache::CacheStats,
7    fault_injection::{self, FaultPoint},
8    graph::{ConnectionWrapper, SqliteGraph},
9};
10
11#[derive(Clone, Debug)]
12pub struct GraphEntityCreate {
13    pub kind: String,
14    pub name: String,
15    pub file_path: Option<String>,
16    pub data: Value,
17}
18
19#[derive(Clone, Debug)]
20pub struct GraphEdgeCreate {
21    pub from_id: i64,
22    pub to_id: i64,
23    pub edge_type: String,
24    pub data: Value,
25}
26
27/// Transaction safety wrapper for automatic rollback on errors
28///
29/// This wrapper holds onto a connection wrapper for the duration of a transaction.
30/// Note: Due to lifetime constraints, this is typically used within a function scope.
31pub struct TransactionGuard<'a> {
32    conn: ConnectionWrapper<'a>,
33    committed: bool,
34}
35
36impl<'a> TransactionGuard<'a> {
37    /// Start a new transaction with IMMEDIATE mode for better write performance
38    pub fn new(conn: ConnectionWrapper<'a>) -> Result<Self, SqliteGraphError> {
39        conn.execute("BEGIN IMMEDIATE", [])
40            .map_err(|e| SqliteGraphError::query(e.to_string()))?;
41        Ok(Self {
42            conn,
43            committed: false,
44        })
45    }
46
47    /// Commit the transaction with cache invalidation and snapshot update
48    pub fn commit(mut self, graph: &SqliteGraph) -> Result<(), SqliteGraphError> {
49        self.conn
50            .execute("COMMIT", [])
51            .map_err(|e| SqliteGraphError::query(e.to_string()))?;
52        graph.invalidate_caches();
53        graph.update_snapshot();
54        self.committed = true;
55        Ok(())
56    }
57
58    /// Get reference to the underlying connection wrapper
59    pub fn conn(&self) -> &ConnectionWrapper<'a> {
60        &self.conn
61    }
62
63    /// Execute a function with automatic rollback on error
64    pub fn execute<F, R>(mut self, graph: &SqliteGraph, f: F) -> Result<R, SqliteGraphError>
65    where
66        F: FnOnce(&ConnectionWrapper<'a>) -> Result<R, SqliteGraphError>,
67    {
68        match f(&self.conn) {
69            Ok(result) => {
70                self.commit(graph)?;
71                Ok(result)
72            }
73            Err(err) => {
74                // Don't rollback here - Drop will handle it automatically
75                self.committed = false; // Ensure Drop knows to rollback
76                Err(err)
77            }
78        }
79    }
80}
81
82impl<'a> Drop for TransactionGuard<'a> {
83    fn drop(&mut self) {
84        if !self.committed {
85            // Auto-rollback if not explicitly committed
86            let _ = self.conn.execute("ROLLBACK", []);
87        }
88    }
89}
90
/// Configuration for batch operations.
///
/// Controls whether `execute_batch` splits its input into chunks and how
/// large each chunk may be.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BatchConfig {
    /// Maximum number of items handed to the batch operation in one call when
    /// chunking is enabled.
    pub max_batch_size: usize,
    /// When `false`, the whole input is always processed in a single call
    /// regardless of `max_batch_size`.
    pub enable_chunking: bool,
}

impl Default for BatchConfig {
    /// Chunking enabled with at most 1000 items per chunk.
    fn default() -> Self {
        Self {
            max_batch_size: 1000, // Conservative default for WAL mode
            enable_chunking: true,
        }
    }
}
105
106/// Execute a batch operation with automatic chunking for large datasets
107pub fn execute_batch<T, F, R>(
108    items: &[T],
109    config: &BatchConfig,
110    mut operation: F,
111) -> Result<Vec<R>, SqliteGraphError>
112where
113    F: FnMut(&[T]) -> Result<Vec<R>, SqliteGraphError>,
114{
115    if !config.enable_chunking || items.len() <= config.max_batch_size {
116        return operation(items);
117    }
118
119    let mut all_results = Vec::with_capacity(items.len());
120
121    // Process in deterministic chunks to maintain ordering
122    for chunk in items.chunks(config.max_batch_size) {
123        let chunk_results = operation(chunk)?;
124        all_results.extend(chunk_results);
125    }
126
127    Ok(all_results)
128}
129
130pub fn bulk_insert_entities(
131    graph: &SqliteGraph,
132    entries: &[GraphEntityCreate],
133) -> Result<Vec<i64>, SqliteGraphError> {
134    bulk_insert_entities_with_config(graph, entries, &BatchConfig::default())
135}
136
137pub fn bulk_insert_entities_with_config(
138    graph: &SqliteGraph,
139    entries: &[GraphEntityCreate],
140    config: &BatchConfig,
141) -> Result<Vec<i64>, SqliteGraphError> {
142    if entries.is_empty() {
143        return Ok(Vec::new());
144    }
145
146    execute_batch(entries, config, |chunk| {
147        let conn = graph.connection();
148        TransactionGuard::new(conn)?.execute(graph, |conn| {
149            let mut stmt = conn
150                .prepare_cached(
151                    "INSERT INTO graph_entities(kind,name,file_path,data) VALUES(?1,?2,?3,?4)",
152                )
153                .map_err(|e| SqliteGraphError::query(e.to_string()))?;
154            let mut ids = Vec::new();
155            for entry in chunk {
156                validate_entity_create(entry)?;
157                let payload = serde_json::to_string(&entry.data)
158                    .map_err(|e| SqliteGraphError::invalid_input(e.to_string()))?;
159                stmt.execute(rusqlite::params![
160                    entry.kind,
161                    entry.name,
162                    entry.file_path,
163                    payload
164                ])
165                .map_err(|e| SqliteGraphError::query(e.to_string()))?;
166                ids.push(conn.last_insert_rowid());
167            }
168
169            // Check for fault injection before commit
170            fault_injection::check_fault(FaultPoint::BulkInsertEntitiesBeforeCommit)?;
171            Ok(ids)
172        })
173    })
174}
175
176pub fn bulk_insert_edges(
177    graph: &SqliteGraph,
178    entries: &[GraphEdgeCreate],
179) -> Result<Vec<i64>, SqliteGraphError> {
180    bulk_insert_edges_with_config(graph, entries, &BatchConfig::default())
181}
182
183pub fn bulk_insert_edges_with_config(
184    graph: &SqliteGraph,
185    entries: &[GraphEdgeCreate],
186    config: &BatchConfig,
187) -> Result<Vec<i64>, SqliteGraphError> {
188    if entries.is_empty() {
189        return Ok(Vec::new());
190    }
191
192    execute_batch(entries, config, |chunk| {
193        let conn = graph.connection();
194        TransactionGuard::new(conn)?.execute(graph, |conn| {
195            let mut stmt = conn
196                .prepare_cached(
197                    "INSERT INTO graph_edges(from_id,to_id,edge_type,data) VALUES(?1,?2,?3,?4)",
198                )
199                .map_err(|e| SqliteGraphError::query(e.to_string()))?;
200            let mut ids = Vec::new();
201            let mut seen = AHashSet::new();
202            for entry in chunk {
203                validate_edge_create(entry)?;
204                if !seen.insert((entry.from_id, entry.to_id, entry.edge_type.clone())) {
205                    continue;
206                }
207                validate_endpoints_exist(&conn, entry.from_id, entry.to_id)?;
208                let payload = serde_json::to_string(&entry.data)
209                    .map_err(|e| SqliteGraphError::invalid_input(e.to_string()))?;
210                stmt.execute(rusqlite::params![
211                    entry.from_id,
212                    entry.to_id,
213                    entry.edge_type,
214                    payload
215                ])
216                .map_err(|e| SqliteGraphError::query(e.to_string()))?;
217                ids.push(conn.last_insert_rowid());
218            }
219
220            // Check for fault injection before commit
221            fault_injection::check_fault(FaultPoint::BulkInsertEdgesBeforeCommit)?;
222            Ok(ids)
223        })
224    })
225}
226
227pub fn adjacency_fetch_outgoing_batch(
228    graph: &SqliteGraph,
229    ids: &[i64],
230) -> Result<Vec<(i64, Vec<i64>)>, SqliteGraphError> {
231    let mut results = Vec::new();
232    for &id in ids {
233        results.push((id, graph.fetch_outgoing(id)?));
234    }
235    results.sort_by(|a, b| a.0.cmp(&b.0));
236    Ok(results)
237}
238
239pub fn adjacency_fetch_incoming_batch(
240    graph: &SqliteGraph,
241    ids: &[i64],
242) -> Result<Vec<(i64, Vec<i64>)>, SqliteGraphError> {
243    let mut results = Vec::new();
244    for &id in ids {
245        results.push((id, graph.fetch_incoming(id)?));
246    }
247    results.sort_by(|a, b| a.0.cmp(&b.0));
248    Ok(results)
249}
250
251pub fn cache_clear_ranges(graph: &SqliteGraph, ids: &[i64]) {
252    for &id in ids {
253        graph.outgoing_cache_ref().remove(id);
254        graph.incoming_cache_ref().remove(id);
255    }
256}
257
258pub fn cache_stats(graph: &SqliteGraph) -> CacheStats {
259    let outgoing = graph.outgoing_cache_ref().stats();
260    let incoming = graph.incoming_cache_ref().stats();
261    CacheStats {
262        hits: outgoing.hits + incoming.hits,
263        misses: outgoing.misses + incoming.misses,
264        entries: outgoing.entries + incoming.entries,
265    }
266}
267
268fn validate_entity_create(entry: &GraphEntityCreate) -> Result<(), SqliteGraphError> {
269    if entry.kind.trim().is_empty() {
270        return Err(SqliteGraphError::invalid_input("entity kind must be set"));
271    }
272    if entry.name.trim().is_empty() {
273        return Err(SqliteGraphError::invalid_input("entity name must be set"));
274    }
275    Ok(())
276}
277
278fn validate_edge_create(entry: &GraphEdgeCreate) -> Result<(), SqliteGraphError> {
279    if entry.edge_type.trim().is_empty() {
280        return Err(SqliteGraphError::invalid_input("edge type must be set"));
281    }
282    if entry.from_id <= 0 || entry.to_id <= 0 {
283        return Err(SqliteGraphError::invalid_input(
284            "edge endpoints must be positive ids",
285        ));
286    }
287    Ok(())
288}
289
290fn validate_endpoints_exist(
291    conn: &ConnectionWrapper<'_>,
292    from: i64,
293    to: i64,
294) -> Result<(), SqliteGraphError> {
295    let mut stmt = conn
296        .prepare_cached("SELECT COUNT(1) FROM graph_entities WHERE id IN (?1, ?2)")
297        .map_err(|e| SqliteGraphError::query(e.to_string()))?;
298    let count: i64 = stmt
299        .query_row(rusqlite::params![from, to], |row| row.get(0))
300        .map_err(|e| SqliteGraphError::query(e.to_string()))?;
301    if count < 2 {
302        return Err(SqliteGraphError::invalid_input("edge endpoints must exist"));
303    }
304    Ok(())
305}