Skip to main content

xz_knowledge_graph/store/
sqlite.rs

1use std::cmp::Ordering;
2use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
3
/// Priority-queue entry for the Dijkstra search in `shortest_path`: `(cost, entity id)`.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed to pop the *lowest* cost first.
/// Costs are compared with `f32::total_cmp`, so the order is total even if a cost is NaN.
/// Ties on cost are broken on the id (also reversed), so equal-cost entries pop in ascending
/// id order and `PartialEq`/`Eq` agree with `Ord`.
struct PathCost(f32, String);

impl PartialEq for PathCost {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for PathCost {}

impl PartialOrd for PathCost {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PathCost {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped (other vs self) to turn the max-heap into a min-heap.
        other
            .0
            .total_cmp(&self.0)
            .then_with(|| other.1.cmp(&self.1))
    }
}
21
22use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
23use tracing::{debug, info};
24
25use crate::config::KgConfig;
26use crate::error::KgError;
27use crate::store::sqlite_schema::{DDL, FTS_TRIGGERS};
28use crate::traits::KnowledgeGraph;
29use crate::types::attribute::AttributeValue;
30use crate::types::confidence::Confidence;
31use crate::types::consistency::{ConsistencyIssue, ConsistencyIssueType, IssueSeverity};
32use crate::types::entity::{Entity, EntityType};
33use crate::types::graph::{GraphStats, PathStep, SubGraph};
34use crate::types::import::{ImportResult, MergeStrategy, UpsertResult};
35use crate::types::provenance::Provenance;
36use crate::types::query::{
37    EntityPage, EntityQuery, RelationQuery,
38};
39use crate::types::relation::{Relation, WeightStrategy};
40
41/// SQLite-backed knowledge graph implementation.
42#[derive(Debug)]
43pub struct SqliteKnowledgeGraph {
44    pool: SqlitePool,
45    #[allow(dead_code)]
46    merge_strategy: MergeStrategy,
47    weight_strategy: WeightStrategy,
48    max_bfs_depth: u32,
49    max_path_search: u32,
50}
51
52impl SqliteKnowledgeGraph {
53    pub async fn new(path: &str, config: KgConfig) -> Result<Self, KgError> {
54        let pool = SqlitePoolOptions::new()
55            .max_connections(config.storage.pool_size)
56            .connect(&format!("sqlite:{}", path))
57            .await
58            .map_err(|e| KgError::Database(e.to_string()))?;
59
60        sqlx::query("PRAGMA journal_mode=WAL")
61            .execute(&pool)
62            .await
63            .map_err(|e| KgError::Database(e.to_string()))?;
64
65        let this = Self {
66            pool,
67            merge_strategy: config.merge_strategy,
68            weight_strategy: config.weight_strategy,
69            max_bfs_depth: config.max_bfs_depth,
70            max_path_search: config.max_path_search,
71        };
72
73        this.run_migrations().await?;
74        Ok(this)
75    }
76
77    async fn run_migrations(&self) -> Result<(), KgError> {
78        for stmt in DDL {
79            sqlx::query(stmt)
80                .execute(&self.pool)
81                .await
82                .map_err(|e| KgError::Database(format!("Migration failed: {}", e)))?;
83        }
84        for stmt in FTS_TRIGGERS {
85            let _ = sqlx::query(stmt).execute(&self.pool).await;
86        }
87        debug!("sqlite schema migrations complete");
88        Ok(())
89    }
90}
91
92#[async_trait::async_trait]
93impl KnowledgeGraph for SqliteKnowledgeGraph {
94    // === Entity Operations ===
95
96    async fn upsert_entity(&self, entity: Entity) -> Result<UpsertResult, KgError> {
97        let existing: Option<EntityRow> = sqlx::query_as(
98            "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
99                    version, source, tags_json, aliases_json
100             FROM entities WHERE id = ?",
101        )
102        .bind(&entity.id)
103        .fetch_optional(&self.pool)
104        .await
105        .map_err(|e| KgError::Database(e.to_string()))?;
106
107        let attrs_json =
108            serde_json::to_string(&entity.attributes).map_err(|e| KgError::Serialization(e.to_string()))?;
109        let tags_json =
110            serde_json::to_string(&entity.tags).map_err(|e| KgError::Serialization(e.to_string()))?;
111        let aliases_json =
112            serde_json::to_string(&entity.aliases).map_err(|e| KgError::Serialization(e.to_string()))?;
113        let entity_type = entity.entity_type.as_str();
114
115        if let Some(row) = existing {
116            let mut changed = Vec::new();
117            let conflicts = Vec::new();
118
119            if row.name != entity.name {
120                changed.push("name".into());
121            }
122            if row.entity_type != entity_type {
123                changed.push("entity_type".into());
124            }
125
126            if changed.is_empty() {
127                return Ok(UpsertResult::Unchanged);
128            }
129
130            sqlx::query(
131                "UPDATE entities SET name=?, entity_type=?, attributes_json=?, description=?,
132                 updated_at=?, version=version+1, source=?, tags_json=?, aliases_json=?
133                 WHERE id=?",
134            )
135            .bind(&entity.name)
136            .bind(&entity_type)
137            .bind(&attrs_json)
138            .bind(&entity.description)
139            .bind(current_epoch_ms() as i64)
140            .bind(&entity.source)
141            .bind(&tags_json)
142            .bind(&aliases_json)
143            .bind(&entity.id)
144            .execute(&self.pool)
145            .await
146            .map_err(|e| KgError::Database(e.to_string()))?;
147
148            Ok(UpsertResult::Updated { changed_fields: changed, conflicts })
149        } else {
150            sqlx::query(
151                "INSERT INTO entities (id, name, entity_type, attributes_json, description,
152                 created_at, updated_at, version, source, tags_json, aliases_json)
153                 VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?, ?, ?)",
154            )
155            .bind(&entity.id)
156            .bind(&entity.name)
157            .bind(&entity_type)
158            .bind(&attrs_json)
159            .bind(&entity.description)
160            .bind(entity.created_at as i64)
161            .bind(entity.updated_at as i64)
162            .bind(&entity.source)
163            .bind(&tags_json)
164            .bind(&aliases_json)
165            .execute(&self.pool)
166            .await
167            .map_err(|e| KgError::Database(e.to_string()))?;
168
169            Ok(UpsertResult::Created)
170        }
171    }
172
173    async fn get_entity(&self, id: &str) -> Result<Option<Entity>, KgError> {
174        let row: Option<EntityRow> = sqlx::query_as(
175            "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
176                    version, source, tags_json, aliases_json
177             FROM entities WHERE id = ?",
178        )
179        .bind(id)
180        .fetch_optional(&self.pool)
181        .await
182        .map_err(|e| KgError::Database(e.to_string()))?;
183
184        Ok(row.map(|r| r.into()))
185    }
186
187    async fn search_entities(&self, query: &EntityQuery) -> Result<EntityPage, KgError> {
188        let use_fts = query.name_contains.is_some();
189
190        let select_cols = "e.id, e.name, e.entity_type, e.attributes_json, e.description, \
191            e.created_at, e.updated_at, e.version, e.source, e.tags_json, e.aliases_json";
192
193        let mut sql = if use_fts {
194            format!(
195                "SELECT {} FROM entities e JOIN entities_fts fts ON e.rowid = fts.rowid WHERE entities_fts MATCH ?",
196                select_cols
197            )
198        } else {
199            format!("SELECT {} FROM entities e WHERE 1=1", select_cols)
200        };
201        let mut params: Vec<String> = Vec::new();
202
203        if use_fts {
204            // FTS5 query: append * for prefix matching
205            let fts_query = format!("{}*", query.name_contains.as_ref().unwrap());
206            params.push(fts_query);
207        } else if let Some(ref name) = query.name_contains {
208            params.push(format!("%{}%", name));
209            sql.push_str(" AND e.name LIKE ?");
210        }
211
212        if let Some(ref aliases) = query.alias_contains {
213            params.push(format!("%{}%", aliases));
214            sql.push_str(" AND e.aliases_json LIKE ?");
215        }
216        if let Some(ref types) = query.entity_types {
217            if !types.is_empty() {
218                let type_strs: Vec<String> = types.iter().map(|t| t.as_str()).collect();
219                let placeholders: Vec<String> = type_strs.iter().map(|_| "?".to_string()).collect();
220                sql.push_str(&format!(" AND e.entity_type IN ({})", placeholders.join(",")));
221                params.extend(type_strs);
222            }
223        }
224        if let Some(ref source) = query.source {
225            params.push(source.clone());
226            sql.push_str(" AND e.source = ?");
227        }
228        // Tag filter
229        if let Some(ref tag_filter) = query.tags {
230            if !tag_filter.tags.is_empty() {
231                match tag_filter.mode {
232                    crate::types::query::TagFilterMode::Or => {
233                        let tag_conditions: Vec<String> = tag_filter.tags.iter().map(|_| {
234                            "e.tags_json LIKE '%' || ? || '%'".to_string()
235                        }).collect();
236                        sql.push_str(&format!(" AND ({})", tag_conditions.join(" OR ")));
237                        params.extend(tag_filter.tags.iter().cloned());
238                    }
239                    crate::types::query::TagFilterMode::And => {
240                        for tag in &tag_filter.tags {
241                            sql.push_str(" AND e.tags_json LIKE '%' || ? || '%'");
242                            params.push(tag.clone());
243                        }
244                    }
245                }
246            }
247        }
248        // Attribute filters
249        for attr in &query.attribute_filters {
250            let json_path = format!("$.{}", attr.key);
251            match attr.operator {
252                crate::types::query::FilterOperator::Eq => {
253                    sql.push_str(" AND json_extract(e.attributes_json, ?) = ?");
254                    params.push(json_path);
255                    params.push(attr.value.clone());
256                }
257                crate::types::query::FilterOperator::Contains => {
258                    sql.push_str(" AND json_extract(e.attributes_json, ?) LIKE '%' || ? || '%'");
259                    params.push(json_path);
260                    params.push(attr.value.clone());
261                }
262                _ => {
263                    // Other operators: use JSON value comparison
264                    sql.push_str(" AND json_extract(e.attributes_json, ?) = ?");
265                    params.push(json_path);
266                    params.push(attr.value.clone());
267                }
268            }
269        }
270
271        // Count
272        let count_sql = sql.replace(
273            &format!("SELECT {}", select_cols),
274            "SELECT COUNT(*)",
275        );
276
277        let mut count_query = sqlx::query_scalar(&count_sql);
278        for p in &params {
279            count_query = count_query.bind(p);
280        }
281        let total: i64 = count_query
282            .fetch_one(&self.pool)
283            .await
284            .map_err(|e| KgError::Database(e.to_string()))?;
285
286        // Sort
287        let order = match query.sort_by {
288            Some(crate::types::query::EntitySortField::Name) => "e.name ASC",
289            Some(crate::types::query::EntitySortField::CreatedAt) => "e.created_at DESC",
290            Some(crate::types::query::EntitySortField::UpdatedAt) => "e.updated_at DESC",
291            Some(crate::types::query::EntitySortField::EntityType) => "e.entity_type ASC",
292            Some(crate::types::query::EntitySortField::RelationCount) => "e.updated_at DESC",
293            None => {
294                if use_fts { "ORDER BY rank" } else { "ORDER BY e.updated_at DESC" }
295            }
296        };
297        sql.push_str(&format!(" {} LIMIT ? OFFSET ?", order));
298
299        let mut fetch_query = sqlx::query_as::<_, EntityRow>(&sql);
300        for p in &params {
301            fetch_query = fetch_query.bind(p);
302        }
303        fetch_query = fetch_query
304            .bind(query.page.limit as i64)
305            .bind(query.page.offset as i64);
306
307        let rows: Vec<EntityRow> = fetch_query
308            .fetch_all(&self.pool)
309            .await
310            .map_err(|e| KgError::Database(e.to_string()))?;
311
312        let total = total as usize;
313        let items: Vec<Entity> = rows.into_iter().map(|r| r.into()).collect();
314        let has_more = query.page.offset + query.page.limit < total;
315
316        Ok(EntityPage { items, total, has_more })
317    }
318
319    async fn delete_entity(&self, id: &str) -> Result<usize, KgError> {
320        let mut txn = self.pool.begin().await.map_err(|e| KgError::Database(e.to_string()))?;
321
322        let outcome = {
323            let conn = std::ops::DerefMut::deref_mut(&mut txn);
324
325            let relation_count: (i64,) = sqlx::query_as(
326                "SELECT COUNT(*) FROM relations WHERE source_id = ? OR target_id = ?",
327            )
328            .bind(id)
329            .bind(id)
330            .fetch_one(&mut *conn)
331            .await
332            .map_err(|e| KgError::Database(e.to_string()))?;
333
334            sqlx::query("DELETE FROM relations WHERE source_id = ? OR target_id = ?")
335                .bind(id)
336                .bind(id)
337                .execute(&mut *conn)
338                .await
339                .map_err(|e| KgError::Database(e.to_string()))?;
340
341            sqlx::query("DELETE FROM entities WHERE id = ?")
342                .bind(id)
343                .execute(&mut *conn)
344                .await
345                .map_err(|e| KgError::Database(e.to_string()))?;
346
347            Ok::<usize, KgError>(relation_count.0 as usize)
348        };
349
350        match outcome {
351            Ok(count) => {
352                txn.commit().await.map_err(|e| KgError::Database(e.to_string()))?;
353                Ok(count)
354            }
355            Err(e) => {
356                let _ = txn.rollback().await;
357                Err(e)
358            }
359        }
360    }
361
362    async fn get_entities_batch(&self, ids: &[&str]) -> Result<Vec<Entity>, KgError> {
363        if ids.is_empty() {
364            return Ok(vec![]);
365        }
366        let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
367        let sql = format!(
368            "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
369                    version, source, tags_json, aliases_json
370             FROM entities WHERE id IN ({})",
371            placeholders.join(",")
372        );
373
374        let mut query = sqlx::query_as::<_, EntityRow>(&sql);
375        for id in ids {
376            query = query.bind(id);
377        }
378
379        let rows: Vec<EntityRow> = query
380            .fetch_all(&self.pool)
381            .await
382            .map_err(|e| KgError::Database(e.to_string()))?;
383
384        Ok(rows.into_iter().map(|r| r.into()).collect())
385    }
386
387    // === Relation Operations ===
388
389    async fn upsert_relation(&self, relation: Relation) -> Result<UpsertResult, KgError> {
390        let existing: Option<RelationRow> = sqlx::query_as(
391            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
392                    provenance_json, valid_from, valid_to, created_at, weight
393             FROM relations WHERE id = ?",
394        )
395        .bind(&relation.id)
396        .fetch_optional(&self.pool)
397        .await
398        .map_err(|e| KgError::Database(e.to_string()))?;
399
400        let props_json = serde_json::to_string(&relation.properties)
401            .map_err(|e| KgError::Serialization(e.to_string()))?;
402        let provenance_json = relation
403            .provenance
404            .as_ref()
405            .map(|p| serde_json::to_string(p))
406            .transpose()
407            .map_err(|e| KgError::Serialization(e.to_string()))?;
408
409        if existing.is_some() {
410            sqlx::query(
411                "UPDATE relations SET relation_type=?, properties_json=?, confidence=?,
412                 provenance_json=?, valid_from=?, valid_to=?, weight=?
413                 WHERE id=?",
414            )
415            .bind(&relation.relation_type)
416            .bind(&props_json)
417            .bind(relation.confidence.as_f32())
418            .bind(&provenance_json)
419            .bind(relation.valid_from.map(|v| v as i64))
420            .bind(relation.valid_to.map(|v| v as i64))
421            .bind(relation.weight)
422            .bind(&relation.id)
423            .execute(&self.pool)
424            .await
425            .map_err(|e| KgError::Database(e.to_string()))?;
426
427            Ok(UpsertResult::Updated {
428                changed_fields: vec!["relation_type".into()],
429                conflicts: vec![],
430            })
431        } else {
432            sqlx::query(
433                "INSERT INTO relations (id, source_id, target_id, relation_type, properties_json,
434                 confidence, provenance_json, valid_from, valid_to, created_at, weight)
435                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
436            )
437            .bind(&relation.id)
438            .bind(&relation.source_id)
439            .bind(&relation.target_id)
440            .bind(&relation.relation_type)
441            .bind(&props_json)
442            .bind(relation.confidence.as_f32())
443            .bind(&provenance_json)
444            .bind(relation.valid_from.map(|v| v as i64))
445            .bind(relation.valid_to.map(|v| v as i64))
446            .bind(relation.created_at as i64)
447            .bind(relation.weight)
448            .execute(&self.pool)
449            .await
450            .map_err(|e| KgError::Database(e.to_string()))?;
451
452            Ok(UpsertResult::Created)
453        }
454    }
455
456    async fn get_relations(&self, entity_id: &str) -> Result<Vec<Relation>, KgError> {
457        let rows: Vec<RelationRow> = sqlx::query_as(
458            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
459                    provenance_json, valid_from, valid_to, created_at, weight
460             FROM relations WHERE source_id = ? OR target_id = ?",
461        )
462        .bind(entity_id)
463        .bind(entity_id)
464        .fetch_all(&self.pool)
465        .await
466        .map_err(|e| KgError::Database(e.to_string()))?;
467
468        Ok(rows.into_iter().map(|r| r.into()).collect())
469    }
470
471    async fn query_relations(&self, query: &RelationQuery) -> Result<Vec<Relation>, KgError> {
472        let mut sql = String::from(
473            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
474                    provenance_json, valid_from, valid_to, created_at, weight
475             FROM relations WHERE 1=1",
476        );
477        let mut params: Vec<String> = Vec::new();
478
479        if let Some(ref sid) = query.source_id {
480            params.push(sid.clone());
481            sql.push_str(" AND source_id = ?");
482        }
483        if let Some(ref tid) = query.target_id {
484            params.push(tid.clone());
485            sql.push_str(" AND target_id = ?");
486        }
487        if let Some(ref eid) = query.entity_id {
488            params.push(eid.clone());
489            params.push(eid.clone());
490            sql.push_str(" AND (source_id = ? OR target_id = ?)");
491        }
492        if let Some(ref rt) = query.relation_type {
493            params.push(rt.clone());
494            sql.push_str(" AND relation_type = ?");
495        }
496        if let Some(ref rts) = query.relation_types {
497            if !rts.is_empty() {
498                let placeholders: Vec<String> = rts.iter().map(|_| "?".to_string()).collect();
499                sql.push_str(&format!(" AND relation_type IN ({})", placeholders.join(",")));
500                params.extend(rts.iter().cloned());
501            }
502        }
503        if let Some(ref min_conf) = query.min_confidence {
504            sql.push_str(" AND confidence >= ?");
505            params.push(min_conf.as_f32().to_string());
506        }
507        if let Some(valid_at) = query.valid_at {
508            sql.push_str(" AND (valid_from IS NULL OR valid_from <= ?) AND (valid_to IS NULL OR valid_to >= ?)");
509            params.push(valid_at.to_string());
510            params.push(valid_at.to_string());
511        }
512
513        sql.push_str(" LIMIT ? OFFSET ?");
514
515        let mut fetch_query = sqlx::query_as::<_, RelationRow>(&sql);
516        for p in &params {
517            fetch_query = fetch_query.bind(p);
518        }
519        fetch_query = fetch_query
520            .bind(query.page.limit as i64)
521            .bind(query.page.offset as i64);
522
523        let rows: Vec<RelationRow> = fetch_query
524            .fetch_all(&self.pool)
525            .await
526            .map_err(|e| KgError::Database(e.to_string()))?;
527
528        Ok(rows.into_iter().map(|r| r.into()).collect())
529    }
530
531    async fn delete_relation(&self, id: &str) -> Result<(), KgError> {
532        let result = sqlx::query("DELETE FROM relations WHERE id = ?")
533            .bind(id)
534            .execute(&self.pool)
535            .await
536            .map_err(|e| KgError::Database(e.to_string()))?;
537
538        if result.rows_affected() == 0 {
539            return Err(KgError::RelationNotFound(id.to_string()));
540        }
541        Ok(())
542    }
543
544    // === Graph Traversal ===
545
546    async fn get_neighbors(&self, entity_id: &str, depth: u32) -> Result<SubGraph, KgError> {
547        if depth > self.max_bfs_depth {
548            return Err(KgError::MaxDepthExceeded {
549                depth,
550                max: self.max_bfs_depth,
551            });
552        }
553
554        let center = self
555            .get_entity(entity_id)
556            .await?
557            .ok_or_else(|| KgError::EntityNotFound(entity_id.to_string()))?;
558
559        let mut visited_entities: HashMap<String, Entity> = HashMap::new();
560        let mut visited_relations: Vec<Relation> = Vec::new();
561        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
562
563        visited_entities.insert(entity_id.to_string(), center.clone());
564        queue.push_back((entity_id.to_string(), 0));
565
566        while let Some((current_id, current_depth)) = queue.pop_front() {
567            if current_depth >= depth {
568                continue;
569            }
570
571            // Get all relations for the current entity
572            let relations = self.get_relations(&current_id).await?;
573            for rel in relations {
574                let neighbor_id = if rel.source_id == current_id {
575                    rel.target_id.clone()
576                } else {
577                    rel.source_id.clone()
578                };
579
580                visited_relations.push(rel);
581
582                if !visited_entities.contains_key(&neighbor_id) {
583                    if let Some(entity) = self.get_entity(&neighbor_id).await? {
584                        visited_entities.insert(neighbor_id.clone(), entity);
585                        queue.push_back((neighbor_id, current_depth + 1));
586                    }
587                }
588            }
589        }
590
591        let entities: Vec<Entity> = visited_entities
592            .into_iter()
593            .filter(|(id, _)| id != entity_id)
594            .map(|(_, e)| e)
595            .collect();
596
597        Ok(SubGraph {
598            center,
599            entities,
600            relations: visited_relations,
601        })
602    }
603
604    async fn shortest_path(
605        &self,
606        from: &str,
607        to: &str,
608    ) -> Result<Option<Vec<PathStep>>, KgError> {
609        if from == to {
610            return Ok(Some(vec![]));
611        }
612
613        // Load all entities and relations into memory for path finding
614        let entity_rows: Vec<EntityRow> = sqlx::query_as(
615            "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
616                    version, source, tags_json, aliases_json FROM entities",
617        )
618        .fetch_all(&self.pool)
619        .await
620        .map_err(|e| KgError::Database(e.to_string()))?;
621
622        let entities: HashMap<String, Entity> =
623            entity_rows.into_iter().map(|r| (r.id.clone(), r.into())).collect();
624
625        let relation_rows: Vec<RelationRow> = sqlx::query_as(
626            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
627                    provenance_json, valid_from, valid_to, created_at, weight FROM relations",
628        )
629        .fetch_all(&self.pool)
630        .await
631        .map_err(|e| KgError::Database(e.to_string()))?;
632
633        let relations: Vec<Relation> = relation_rows.into_iter().map(|r| r.into()).collect();
634
635        // Build adjacency lists
636        let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
637        for rel in &relations {
638            adj.entry(rel.source_id.clone())
639                .or_default()
640                .push((rel.target_id.clone(), rel.clone()));
641            adj.entry(rel.target_id.clone())
642                .or_default()
643                .push((rel.source_id.clone(), rel.clone()));
644        }
645
646        let mut dist: HashMap<String, f32> = HashMap::new();
647        let mut prev: HashMap<String, (String, Relation)> = HashMap::new();
648        let initial_dist = f32::MAX;
649
650        for id in entities.keys() {
651            dist.insert(id.clone(), initial_dist);
652        }
653        dist.insert(from.to_string(), 0.0);
654
655        let mut queue: BinaryHeap<PathCost> = BinaryHeap::new();
656        queue.push(PathCost(0.0, from.to_string()));
657
658        while let Some(PathCost(_d, u)) = queue.pop() {
659            if let Some(neighbors) = adj.get(&u) {
660                for (v, rel) in neighbors {
661                    let weight = self.weight_strategy.relation_cost(rel);
662                    let alt = dist.get(&u).copied().unwrap_or(initial_dist) + weight;
663                    if alt < dist.get(v).copied().unwrap_or(initial_dist) {
664                        dist.insert(v.clone(), alt);
665                        prev.insert(v.clone(), (u.clone(), rel.clone()));
666                        queue.push(PathCost(alt, v.clone()));
667                    }
668                }
669            }
670        }
671
672        if !prev.contains_key(to) && from != to {
673            return Ok(None);
674        }
675
676        // Reconstruct path
677        let mut path = Vec::new();
678        let mut current = to.to_string();
679        while current != from {
680            if let Some((prev_node, rel)) = prev.get(&current) {
681                let entity = entities.get(&current).cloned().unwrap();
682                path.push(PathStep { entity, relation: rel.clone() });
683                current = prev_node.clone();
684            } else {
685                break;
686            }
687        }
688        // Add the starting entity
689        path.reverse();
690
691        Ok(Some(path))
692    }
693
694    async fn all_paths(
695        &self,
696        from: &str,
697        to: &str,
698        max_depth: u32,
699    ) -> Result<Vec<Vec<PathStep>>, KgError> {
700        if max_depth > self.max_path_search {
701            return Err(KgError::MaxDepthExceeded {
702                depth: max_depth,
703                max: self.max_path_search,
704            });
705        }
706
707        let entity_rows: Vec<EntityRow> = sqlx::query_as(
708            "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
709                    version, source, tags_json, aliases_json FROM entities",
710        )
711        .fetch_all(&self.pool)
712        .await
713        .map_err(|e| KgError::Database(e.to_string()))?;
714
715        let entities: HashMap<String, Entity> =
716            entity_rows.into_iter().map(|r| (r.id.clone(), r.into())).collect();
717
718        let relation_rows: Vec<RelationRow> = sqlx::query_as(
719            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
720                    provenance_json, valid_from, valid_to, created_at, weight FROM relations",
721        )
722        .fetch_all(&self.pool)
723        .await
724        .map_err(|e| KgError::Database(e.to_string()))?;
725
726        let relations: Vec<Relation> = relation_rows.into_iter().map(|r| r.into()).collect();
727
728        // Build adjacency lists
729        let mut adj: HashMap<String, Vec<(String, Relation)>> = HashMap::new();
730        for rel in &relations {
731            adj.entry(rel.source_id.clone())
732                .or_default()
733                .push((rel.target_id.clone(), rel.clone()));
734            adj.entry(rel.target_id.clone())
735                .or_default()
736                .push((rel.source_id.clone(), rel.clone()));
737        }
738
739        let mut all_paths: Vec<Vec<PathStep>> = Vec::new();
740        let mut visited: HashSet<String> = HashSet::new();
741        let mut current_path: Vec<PathStep> = Vec::new();
742
743        dfs_all_paths(
744            from,
745            to,
746            max_depth,
747            &entities,
748            &adj,
749            &mut visited,
750            &mut current_path,
751            &mut all_paths,
752        );
753
754        all_paths.sort_by(|a, b| {
755            let a_cost: f32 = a.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
756            let b_cost: f32 = b.iter().map(|step| self.weight_strategy.relation_cost(&step.relation)).sum();
757            a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
758        });
759
760        Ok(all_paths)
761    }
762
763    // === Batch Operations ===
764
765    async fn batch_import(
766        &self,
767        entities: Vec<Entity>,
768        relations: Vec<Relation>,
769    ) -> Result<ImportResult, KgError> {
770        let mut txn = self.pool.begin().await.map_err(|e| KgError::Database(e.to_string()))?;
771
772        // Scoped to allow fallback to rollback after conn is dropped
773        let outcome = {
774            let conn = std::ops::DerefMut::deref_mut(&mut txn);
775            let mut result = ImportResult::default();
776
777            for entity in &entities {
778                match batch_upsert_entity(conn, entity).await {
779                    Ok(UpsertResult::Created) => result.entities_created += 1,
780                    Ok(UpsertResult::Updated { conflicts, .. }) => {
781                        result.entities_updated += 1;
782                        result.conflicts.extend(conflicts);
783                    }
784                    Ok(UpsertResult::Unchanged) => result.entities_skipped += 1,
785                    Err(e) => return Err(e),
786                }
787            }
788
789            for relation in &relations {
790                match batch_upsert_relation(conn, relation).await {
791                    Ok(UpsertResult::Created) => result.relations_created += 1,
792                    Ok(UpsertResult::Updated { .. }) => result.relations_updated += 1,
793                    Ok(UpsertResult::Unchanged) => {}
794                    Err(e) => return Err(e),
795                }
796            }
797
798            Ok(result)
799        };
800
801        match outcome {
802            Ok(result) => {
803                txn.commit().await.map_err(|e| KgError::Database(e.to_string()))?;
804                info!(
805                    entities_created = %result.entities_created,
806                    entities_updated = %result.entities_updated,
807                    relations_created = %result.relations_created,
808                    "batch import completed"
809                );
810                Ok(result)
811            }
812            Err(e) => {
813                let _ = txn.rollback().await;
814                Err(e)
815            }
816        }
817    }
818
819    // === Consistency ===
820
821    async fn check_consistency(&self) -> Result<Vec<ConsistencyIssue>, KgError> {
822        let mut issues = Vec::new();
823
824        // Check 1: Orphan relations
825        let orphans: Vec<OrphanRelationRow> = sqlx::query_as(
826            "SELECT r.id, r.source_id, r.target_id
827             FROM relations r
828             LEFT JOIN entities e1 ON r.source_id = e1.id
829             LEFT JOIN entities e2 ON r.target_id = e2.id
830             WHERE e1.id IS NULL OR e2.id IS NULL",
831        )
832        .fetch_all(&self.pool)
833        .await
834        .map_err(|e| KgError::Database(e.to_string()))?;
835
836        for o in orphans {
837            issues.push(ConsistencyIssue {
838                severity: IssueSeverity::Error,
839                issue_type: ConsistencyIssueType::OrphanRelation,
840                description: format!("Relation {} references a non-existent entity", o.id),
841                related_entities: vec![o.source_id, o.target_id],
842                related_relations: vec![o.id],
843            });
844        }
845
846        // Check 2: Self-referencing
847        let self_refs: Vec<RelationRow> = sqlx::query_as(
848            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
849                    provenance_json, valid_from, valid_to, created_at, weight
850             FROM relations WHERE source_id = target_id",
851        )
852        .fetch_all(&self.pool)
853        .await
854        .map_err(|e| KgError::Database(e.to_string()))?;
855
856        for rel in self_refs {
857            issues.push(ConsistencyIssue {
858                severity: IssueSeverity::Warning,
859                issue_type: ConsistencyIssueType::SelfReferencing,
860                description: format!("Relation {} self-references entity {}", rel.id, rel.source_id),
861                related_entities: vec![rel.source_id],
862                related_relations: vec![rel.id],
863            });
864        }
865
866        // Check 3: Orphan entities
867        let orphan_entities: Vec<(String, String)> = sqlx::query_as(
868            "SELECT e.id, e.name FROM entities e
869             WHERE e.id NOT IN (SELECT source_id FROM relations)
870               AND e.id NOT IN (SELECT target_id FROM relations)",
871        )
872        .fetch_all(&self.pool)
873        .await
874        .map_err(|e| KgError::Database(e.to_string()))?;
875
876        for (id, name) in orphan_entities {
877            issues.push(ConsistencyIssue {
878                severity: IssueSeverity::Info,
879                issue_type: ConsistencyIssueType::OrphanEntity,
880                description: format!("Entity {} ({}) has no relations", name, id),
881                related_entities: vec![id],
882                related_relations: vec![],
883            });
884        }
885
886        // Check 4: Expired relations
887        let now = current_epoch_ms();
888        let expired: Vec<RelationRow> = sqlx::query_as(
889            "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
890                    provenance_json, valid_from, valid_to, created_at, weight
891             FROM relations WHERE valid_to IS NOT NULL AND valid_to < ?",
892        )
893        .bind(now as i64)
894        .fetch_all(&self.pool)
895        .await
896        .map_err(|e| KgError::Database(e.to_string()))?;
897
898        for rel in expired {
899            issues.push(ConsistencyIssue {
900                severity: IssueSeverity::Warning,
901                issue_type: ConsistencyIssueType::ExpiredRelation,
902                description: format!("Relation {} has expired (valid_to < now)", rel.id),
903                related_entities: vec![rel.source_id, rel.target_id],
904                related_relations: vec![rel.id],
905            });
906        }
907
908        Ok(issues)
909    }
910
911    // === Statistics ===
912
913    async fn stats(&self) -> Result<GraphStats, KgError> {
914        let total_entities: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM entities")
915            .fetch_one(&self.pool)
916            .await
917            .map_err(|e| KgError::Database(e.to_string()))?;
918
919        let total_relations: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM relations")
920            .fetch_one(&self.pool)
921            .await
922            .map_err(|e| KgError::Database(e.to_string()))?;
923
924        let entity_types: Vec<(String, i64)> = sqlx::query_as(
925            "SELECT entity_type, COUNT(*) as cnt FROM entities GROUP BY entity_type",
926        )
927        .fetch_all(&self.pool)
928        .await
929        .map_err(|e| KgError::Database(e.to_string()))?;
930
931        let relation_types: Vec<(String, i64)> = sqlx::query_as(
932            "SELECT relation_type, COUNT(*) as cnt FROM relations GROUP BY relation_type",
933        )
934        .fetch_all(&self.pool)
935        .await
936        .map_err(|e| KgError::Database(e.to_string()))?;
937
938        // Calculate degrees
939        let degrees: Vec<(i64,)> = sqlx::query_as(
940            "SELECT COUNT(*) FROM relations GROUP BY source_id
941             UNION ALL SELECT COUNT(*) FROM relations GROUP BY target_id",
942        )
943        .fetch_all(&self.pool)
944        .await
945        .map_err(|e| KgError::Database(e.to_string()))?;
946
947        let degree_values: Vec<usize> = degrees.into_iter().map(|d| d.0 as usize).collect();
948        let avg_degree = if degree_values.is_empty() {
949            0.0
950        } else {
951            degree_values.iter().sum::<usize>() as f64 / degree_values.len() as f64
952        };
953        let max_degree = degree_values.iter().max().copied().unwrap_or(0);
954
955        // Orphan entities
956        let orphan_entities: (i64,) = sqlx::query_as(
957            "SELECT COUNT(*) FROM entities e
958             WHERE e.id NOT IN (SELECT source_id FROM relations)
959               AND e.id NOT IN (SELECT target_id FROM relations)",
960        )
961        .fetch_one(&self.pool)
962        .await
963        .map_err(|e| KgError::Database(e.to_string()))?;
964
965        // DB size
966        let db_size: (i64,) = sqlx::query_as(
967            "SELECT COALESCE(SUM(pgsize), 0) FROM dbstat",
968        )
969        .fetch_one(&self.pool)
970        .await
971        .map_err(|e| KgError::Database(e.to_string()))?;
972
973        Ok(GraphStats {
974            total_entities: total_entities.0 as usize,
975            total_relations: total_relations.0 as usize,
976            entity_types: entity_types.into_iter().map(|(k, v)| (k, v as usize)).collect(),
977            relation_types: relation_types.into_iter().map(|(k, v)| (k, v as usize)).collect(),
978            avg_degree,
979            max_degree,
980            orphan_entities: orphan_entities.0 as usize,
981            db_size_bytes: db_size.0 as u64,
982        })
983    }
984}
985
986// === DFS helper ===
987
988#[allow(clippy::too_many_arguments)]
989fn dfs_all_paths(
990    current: &str,
991    target: &str,
992    max_depth: u32,
993    entities: &HashMap<String, Entity>,
994    adj: &HashMap<String, Vec<(String, Relation)>>,
995    visited: &mut HashSet<String>,
996    current_path: &mut Vec<PathStep>,
997    all_paths: &mut Vec<Vec<PathStep>>,
998) {
999    if current == target {
1000        all_paths.push(current_path.clone());
1001        return;
1002    }
1003    if current_path.len() >= max_depth as usize {
1004        return;
1005    }
1006    visited.insert(current.to_string());
1007
1008    if let Some(neighbors) = adj.get(current) {
1009        for (neighbor, rel) in neighbors {
1010            if visited.contains(neighbor.as_str()) {
1011                continue;
1012            }
1013            if let Some(entity) = entities.get(neighbor).cloned() {
1014                current_path.push(PathStep {
1015                    entity,
1016                    relation: rel.clone(),
1017                });
1018                dfs_all_paths(
1019                    neighbor, target, max_depth, entities, adj,
1020                    visited, current_path, all_paths,
1021                );
1022                current_path.pop();
1023            }
1024        }
1025    }
1026
1027    visited.remove(current);
1028}
1029
1030// === Batch helpers (operate on a &mut SqliteConnection within a transaction) ===
1031
1032async fn batch_upsert_entity(
1033    conn: &mut sqlx::SqliteConnection,
1034    entity: &Entity,
1035) -> Result<UpsertResult, KgError> {
1036    let existing: Option<EntityRow> = sqlx::query_as(
1037        "SELECT id, name, entity_type, attributes_json, description, created_at, updated_at,
1038                version, source, tags_json, aliases_json
1039         FROM entities WHERE id = ?",
1040    )
1041    .bind(&entity.id)
1042    .fetch_optional(&mut *conn)
1043    .await
1044    .map_err(|e| KgError::Database(e.to_string()))?;
1045
1046    let attrs_json =
1047        serde_json::to_string(&entity.attributes).map_err(|e| KgError::Serialization(e.to_string()))?;
1048    let tags_json =
1049        serde_json::to_string(&entity.tags).map_err(|e| KgError::Serialization(e.to_string()))?;
1050    let aliases_json =
1051        serde_json::to_string(&entity.aliases).map_err(|e| KgError::Serialization(e.to_string()))?;
1052    let entity_type = entity.entity_type.as_str();
1053
1054    if let Some(row) = existing {
1055        let mut changed = Vec::new();
1056        let conflicts = Vec::new();
1057
1058        if row.name != entity.name {
1059            changed.push("name".into());
1060        }
1061        if row.entity_type != entity_type {
1062            changed.push("entity_type".into());
1063        }
1064
1065        if changed.is_empty() {
1066            return Ok(UpsertResult::Unchanged);
1067        }
1068
1069        sqlx::query(
1070            "UPDATE entities SET name=?, entity_type=?, attributes_json=?, description=?,
1071             updated_at=?, version=version+1, source=?, tags_json=?, aliases_json=?
1072             WHERE id=?",
1073        )
1074        .bind(&entity.name)
1075        .bind(&entity_type)
1076        .bind(&attrs_json)
1077        .bind(&entity.description)
1078        .bind(current_epoch_ms() as i64)
1079        .bind(&entity.source)
1080        .bind(&tags_json)
1081        .bind(&aliases_json)
1082        .bind(&entity.id)
1083        .execute(&mut *conn)
1084        .await
1085        .map_err(|e| KgError::Database(e.to_string()))?;
1086
1087        Ok(UpsertResult::Updated { changed_fields: changed, conflicts })
1088    } else {
1089        sqlx::query(
1090            "INSERT INTO entities (id, name, entity_type, attributes_json, description,
1091             created_at, updated_at, version, source, tags_json, aliases_json)
1092             VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?, ?, ?)",
1093        )
1094        .bind(&entity.id)
1095        .bind(&entity.name)
1096        .bind(&entity_type)
1097        .bind(&attrs_json)
1098        .bind(&entity.description)
1099        .bind(entity.created_at as i64)
1100        .bind(entity.updated_at as i64)
1101        .bind(&entity.source)
1102        .bind(&tags_json)
1103        .bind(&aliases_json)
1104        .execute(&mut *conn)
1105        .await
1106        .map_err(|e| KgError::Database(e.to_string()))?;
1107
1108        Ok(UpsertResult::Created)
1109    }
1110}
1111
1112async fn batch_upsert_relation(
1113    conn: &mut sqlx::SqliteConnection,
1114    relation: &Relation,
1115) -> Result<UpsertResult, KgError> {
1116    let existing: Option<RelationRow> = sqlx::query_as(
1117        "SELECT id, source_id, target_id, relation_type, properties_json, confidence,
1118                provenance_json, valid_from, valid_to, created_at, weight
1119         FROM relations WHERE id = ?",
1120    )
1121    .bind(&relation.id)
1122    .fetch_optional(&mut *conn)
1123    .await
1124    .map_err(|e| KgError::Database(e.to_string()))?;
1125
1126    let props_json = serde_json::to_string(&relation.properties)
1127        .map_err(|e| KgError::Serialization(e.to_string()))?;
1128    let provenance_json = relation
1129        .provenance
1130        .as_ref()
1131        .map(|p| serde_json::to_string(p))
1132        .transpose()
1133        .map_err(|e| KgError::Serialization(e.to_string()))?;
1134
1135    if existing.is_some() {
1136        sqlx::query(
1137            "UPDATE relations SET relation_type=?, properties_json=?, confidence=?,
1138             provenance_json=?, valid_from=?, valid_to=?, weight=?
1139             WHERE id=?",
1140        )
1141        .bind(&relation.relation_type)
1142        .bind(&props_json)
1143        .bind(relation.confidence.as_f32())
1144        .bind(&provenance_json)
1145        .bind(relation.valid_from.map(|v| v as i64))
1146        .bind(relation.valid_to.map(|v| v as i64))
1147        .bind(relation.weight)
1148        .bind(&relation.id)
1149        .execute(&mut *conn)
1150        .await
1151        .map_err(|e| KgError::Database(e.to_string()))?;
1152
1153        Ok(UpsertResult::Updated {
1154            changed_fields: vec!["relation_type".into()],
1155            conflicts: vec![],
1156        })
1157    } else {
1158        sqlx::query(
1159            "INSERT INTO relations (id, source_id, target_id, relation_type, properties_json,
1160             confidence, provenance_json, valid_from, valid_to, created_at, weight)
1161             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
1162        )
1163        .bind(&relation.id)
1164        .bind(&relation.source_id)
1165        .bind(&relation.target_id)
1166        .bind(&relation.relation_type)
1167        .bind(&props_json)
1168        .bind(relation.confidence.as_f32())
1169        .bind(&provenance_json)
1170        .bind(relation.valid_from.map(|v| v as i64))
1171        .bind(relation.valid_to.map(|v| v as i64))
1172        .bind(relation.created_at as i64)
1173        .bind(relation.weight)
1174        .execute(&mut *conn)
1175        .await
1176        .map_err(|e| KgError::Database(e.to_string()))?;
1177
1178        Ok(UpsertResult::Created)
1179    }
1180}
1181
1182// === Row types ===
1183
/// One row of the `entities` table, decoded column-by-column via `sqlx::FromRow`.
///
/// Converted into [`Entity`] by the `From<EntityRow>` impl. The `*_json` columns hold
/// text produced by `serde_json::to_string` in `batch_upsert_entity`.
#[derive(Debug, sqlx::FromRow)]
struct EntityRow {
    id: String,
    name: String,
    /// String form of the entity type (written from `EntityType::as_str`).
    entity_type: String,
    /// JSON object; decoded as `HashMap<String, AttributeValue>`.
    attributes_json: String,
    description: Option<String>,
    /// Integer timestamp; presumably epoch milliseconds (the batch update path writes
    /// `current_epoch_ms()` into `updated_at`). NOTE(review): confirm for inserted rows.
    created_at: i64,
    updated_at: i64,
    /// Starts at 1 on insert; the batch update path increments it by one per UPDATE.
    version: i64,
    source: Option<String>,
    /// JSON array of strings.
    tags_json: String,
    /// JSON array of strings.
    aliases_json: String,
}
1198
1199impl From<EntityRow> for Entity {
1200    fn from(r: EntityRow) -> Self {
1201        let attributes: HashMap<String, AttributeValue> =
1202            serde_json::from_str(&r.attributes_json).unwrap_or_default();
1203        let tags: Vec<String> = serde_json::from_str(&r.tags_json).unwrap_or_default();
1204        let aliases: Vec<String> = serde_json::from_str(&r.aliases_json).unwrap_or_default();
1205
1206        Self {
1207            id: r.id,
1208            name: r.name,
1209            entity_type: EntityType::from_str(&r.entity_type),
1210            attributes,
1211            description: r.description,
1212            created_at: r.created_at as u64,
1213            updated_at: r.updated_at as u64,
1214            version: r.version as u64,
1215            source: r.source,
1216            tags,
1217            aliases,
1218        }
1219    }
1220}
1221
/// One row of the `relations` table, decoded column-by-column via `sqlx::FromRow`.
///
/// Converted into [`Relation`] by the `From<RelationRow>` impl.
#[derive(Debug, sqlx::FromRow)]
struct RelationRow {
    id: String,
    /// Source entity id. It may reference a missing entity; `check_consistency` looks for
    /// such rows.
    source_id: String,
    /// Target entity id (same caveat as `source_id`).
    target_id: String,
    relation_type: String,
    /// JSON object; decoded as `HashMap<String, String>`.
    properties_json: String,
    /// Written from `Confidence::as_f32`; rebuilt with `Confidence::from_f32`.
    confidence: f32,
    /// JSON-serialized `Provenance`, or NULL when the relation has none.
    provenance_json: Option<String>,
    /// Optional validity bounds. `check_consistency` compares `valid_to` with
    /// `current_epoch_ms()`, so epoch milliseconds are presumably intended — confirm.
    valid_from: Option<i64>,
    valid_to: Option<i64>,
    created_at: i64,
    /// Optional weight; NULL maps to `None`.
    weight: Option<f32>,
}
1236
1237impl From<RelationRow> for Relation {
1238    fn from(r: RelationRow) -> Self {
1239        let properties: HashMap<String, String> =
1240            serde_json::from_str(&r.properties_json).unwrap_or_default();
1241        let provenance: Option<Provenance> = r
1242            .provenance_json
1243            .and_then(|j| serde_json::from_str(&j).ok());
1244
1245        Self {
1246            id: r.id,
1247            source_id: r.source_id,
1248            target_id: r.target_id,
1249            relation_type: r.relation_type,
1250            properties,
1251            confidence: Confidence::from_f32(r.confidence),
1252            provenance,
1253            valid_from: r.valid_from.map(|v| v as u64),
1254            valid_to: r.valid_to.map(|v| v as u64),
1255            created_at: r.created_at as u64,
1256            weight: r.weight,
1257        }
1258    }
1259}
1260
/// Minimal projection returned by the orphan-relation query in `check_consistency`.
///
/// Each row is a relation whose source or target id has no matching row in `entities`.
#[derive(Debug, sqlx::FromRow)]
struct OrphanRelationRow {
    id: String,
    source_id: String,
    target_id: String,
}
1267
1268// === Utility ===
1269
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reports a time before the epoch. The `u128` millisecond
/// count is narrowed to `u64` with `as`.
fn current_epoch_ms() -> u64 {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    elapsed.as_millis() as u64
}