1use ahash::AHashSet;
2use serde_json::Value;
3
4use crate::{
5 SqliteGraphError,
6 cache::CacheStats,
7 fault_injection::{self, FaultPoint},
8 graph::{ConnectionWrapper, SqliteGraph},
9};
10
11#[derive(Clone, Debug)]
12pub struct GraphEntityCreate {
13 pub kind: String,
14 pub name: String,
15 pub file_path: Option<String>,
16 pub data: Value,
17}
18
19#[derive(Clone, Debug)]
20pub struct GraphEdgeCreate {
21 pub from_id: i64,
22 pub to_id: i64,
23 pub edge_type: String,
24 pub data: Value,
25}
26
27pub struct TransactionGuard<'a> {
32 conn: ConnectionWrapper<'a>,
33 committed: bool,
34}
35
36impl<'a> TransactionGuard<'a> {
37 pub fn new(conn: ConnectionWrapper<'a>) -> Result<Self, SqliteGraphError> {
39 conn.execute("BEGIN IMMEDIATE", [])
40 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
41 Ok(Self {
42 conn,
43 committed: false,
44 })
45 }
46
47 pub fn commit(mut self, graph: &SqliteGraph) -> Result<(), SqliteGraphError> {
49 self.conn
50 .execute("COMMIT", [])
51 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
52 graph.invalidate_caches();
53 graph.update_snapshot();
54 self.committed = true;
55 Ok(())
56 }
57
58 pub fn conn(&self) -> &ConnectionWrapper<'a> {
60 &self.conn
61 }
62
63 pub fn execute<F, R>(mut self, graph: &SqliteGraph, f: F) -> Result<R, SqliteGraphError>
65 where
66 F: FnOnce(&ConnectionWrapper<'a>) -> Result<R, SqliteGraphError>,
67 {
68 match f(&self.conn) {
69 Ok(result) => {
70 self.commit(graph)?;
71 Ok(result)
72 }
73 Err(err) => {
74 self.committed = false; Err(err)
77 }
78 }
79 }
80}
81
82impl<'a> Drop for TransactionGuard<'a> {
83 fn drop(&mut self) {
84 if !self.committed {
85 let _ = self.conn.execute("ROLLBACK", []);
87 }
88 }
89}
90
/// Settings that control how [`execute_batch`] splits its input.
pub struct BatchConfig {
    /// Largest number of items passed to the operation in one call when
    /// chunking is enabled.
    pub max_batch_size: usize,
    /// When `false`, all items are passed to the operation in a single call.
    pub enable_chunking: bool,
}

impl Default for BatchConfig {
    /// Chunking enabled with batches of at most 1000 items.
    fn default() -> Self {
        BatchConfig {
            enable_chunking: true,
            max_batch_size: 1_000,
        }
    }
}
105
106pub fn execute_batch<T, F, R>(
108 items: &[T],
109 config: &BatchConfig,
110 mut operation: F,
111) -> Result<Vec<R>, SqliteGraphError>
112where
113 F: FnMut(&[T]) -> Result<Vec<R>, SqliteGraphError>,
114{
115 if !config.enable_chunking || items.len() <= config.max_batch_size {
116 return operation(items);
117 }
118
119 let mut all_results = Vec::with_capacity(items.len());
120
121 for chunk in items.chunks(config.max_batch_size) {
123 let chunk_results = operation(chunk)?;
124 all_results.extend(chunk_results);
125 }
126
127 Ok(all_results)
128}
129
130pub fn bulk_insert_entities(
131 graph: &SqliteGraph,
132 entries: &[GraphEntityCreate],
133) -> Result<Vec<i64>, SqliteGraphError> {
134 bulk_insert_entities_with_config(graph, entries, &BatchConfig::default())
135}
136
137pub fn bulk_insert_entities_with_config(
138 graph: &SqliteGraph,
139 entries: &[GraphEntityCreate],
140 config: &BatchConfig,
141) -> Result<Vec<i64>, SqliteGraphError> {
142 if entries.is_empty() {
143 return Ok(Vec::new());
144 }
145
146 execute_batch(entries, config, |chunk| {
147 let conn = graph.connection();
148 TransactionGuard::new(conn)?.execute(graph, |conn| {
149 let mut stmt = conn
150 .prepare_cached(
151 "INSERT INTO graph_entities(kind,name,file_path,data) VALUES(?1,?2,?3,?4)",
152 )
153 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
154 let mut ids = Vec::new();
155 for entry in chunk {
156 validate_entity_create(entry)?;
157 let payload = serde_json::to_string(&entry.data)
158 .map_err(|e| SqliteGraphError::invalid_input(e.to_string()))?;
159 stmt.execute(rusqlite::params![
160 entry.kind,
161 entry.name,
162 entry.file_path,
163 payload
164 ])
165 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
166 ids.push(conn.last_insert_rowid());
167 }
168
169 fault_injection::check_fault(FaultPoint::BulkInsertEntitiesBeforeCommit)?;
171 Ok(ids)
172 })
173 })
174}
175
176pub fn bulk_insert_edges(
177 graph: &SqliteGraph,
178 entries: &[GraphEdgeCreate],
179) -> Result<Vec<i64>, SqliteGraphError> {
180 bulk_insert_edges_with_config(graph, entries, &BatchConfig::default())
181}
182
183pub fn bulk_insert_edges_with_config(
184 graph: &SqliteGraph,
185 entries: &[GraphEdgeCreate],
186 config: &BatchConfig,
187) -> Result<Vec<i64>, SqliteGraphError> {
188 if entries.is_empty() {
189 return Ok(Vec::new());
190 }
191
192 execute_batch(entries, config, |chunk| {
193 let conn = graph.connection();
194 TransactionGuard::new(conn)?.execute(graph, |conn| {
195 let mut stmt = conn
196 .prepare_cached(
197 "INSERT INTO graph_edges(from_id,to_id,edge_type,data) VALUES(?1,?2,?3,?4)",
198 )
199 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
200 let mut ids = Vec::new();
201 let mut seen = AHashSet::new();
202 for entry in chunk {
203 validate_edge_create(entry)?;
204 if !seen.insert((entry.from_id, entry.to_id, entry.edge_type.clone())) {
205 continue;
206 }
207 validate_endpoints_exist(&conn, entry.from_id, entry.to_id)?;
208 let payload = serde_json::to_string(&entry.data)
209 .map_err(|e| SqliteGraphError::invalid_input(e.to_string()))?;
210 stmt.execute(rusqlite::params![
211 entry.from_id,
212 entry.to_id,
213 entry.edge_type,
214 payload
215 ])
216 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
217 ids.push(conn.last_insert_rowid());
218 }
219
220 fault_injection::check_fault(FaultPoint::BulkInsertEdgesBeforeCommit)?;
222 Ok(ids)
223 })
224 })
225}
226
227pub fn adjacency_fetch_outgoing_batch(
228 graph: &SqliteGraph,
229 ids: &[i64],
230) -> Result<Vec<(i64, Vec<i64>)>, SqliteGraphError> {
231 let mut results = Vec::new();
232 for &id in ids {
233 results.push((id, graph.fetch_outgoing(id)?));
234 }
235 results.sort_by(|a, b| a.0.cmp(&b.0));
236 Ok(results)
237}
238
239pub fn adjacency_fetch_incoming_batch(
240 graph: &SqliteGraph,
241 ids: &[i64],
242) -> Result<Vec<(i64, Vec<i64>)>, SqliteGraphError> {
243 let mut results = Vec::new();
244 for &id in ids {
245 results.push((id, graph.fetch_incoming(id)?));
246 }
247 results.sort_by(|a, b| a.0.cmp(&b.0));
248 Ok(results)
249}
250
251pub fn cache_clear_ranges(graph: &SqliteGraph, ids: &[i64]) {
252 for &id in ids {
253 graph.outgoing_cache_ref().remove(id);
254 graph.incoming_cache_ref().remove(id);
255 }
256}
257
258pub fn cache_stats(graph: &SqliteGraph) -> CacheStats {
259 let outgoing = graph.outgoing_cache_ref().stats();
260 let incoming = graph.incoming_cache_ref().stats();
261 CacheStats {
262 hits: outgoing.hits + incoming.hits,
263 misses: outgoing.misses + incoming.misses,
264 entries: outgoing.entries + incoming.entries,
265 }
266}
267
268fn validate_entity_create(entry: &GraphEntityCreate) -> Result<(), SqliteGraphError> {
269 if entry.kind.trim().is_empty() {
270 return Err(SqliteGraphError::invalid_input("entity kind must be set"));
271 }
272 if entry.name.trim().is_empty() {
273 return Err(SqliteGraphError::invalid_input("entity name must be set"));
274 }
275 Ok(())
276}
277
278fn validate_edge_create(entry: &GraphEdgeCreate) -> Result<(), SqliteGraphError> {
279 if entry.edge_type.trim().is_empty() {
280 return Err(SqliteGraphError::invalid_input("edge type must be set"));
281 }
282 if entry.from_id <= 0 || entry.to_id <= 0 {
283 return Err(SqliteGraphError::invalid_input(
284 "edge endpoints must be positive ids",
285 ));
286 }
287 Ok(())
288}
289
290fn validate_endpoints_exist(
291 conn: &ConnectionWrapper<'_>,
292 from: i64,
293 to: i64,
294) -> Result<(), SqliteGraphError> {
295 let mut stmt = conn
296 .prepare_cached("SELECT COUNT(1) FROM graph_entities WHERE id IN (?1, ?2)")
297 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
298 let count: i64 = stmt
299 .query_row(rusqlite::params![from, to], |row| row.get(0))
300 .map_err(|e| SqliteGraphError::query(e.to_string()))?;
301 if count < 2 {
302 return Err(SqliteGraphError::invalid_input("edge endpoints must exist"));
303 }
304 Ok(())
305}