Skip to main content

sqlite_knowledge_graph/
extension.rs

1//! SQLite extension entry point using sqlite-loadable
2//!
3//! This module provides the SQLite loadable extension interface.
4
5use sqlite_loadable::{
6    define_scalar_function, ext::sqlite3ext_result_text, prelude::*, Error, FunctionFlags,
7};
8use std::ffi::CString;
9
10/// Helper function to return an integer result
11fn result_int64(context: *mut sqlite3_context, value: i64) {
12    unsafe {
13        sqlite_loadable::ext::sqlite3ext_result_int64(context, value);
14    }
15}
16
17/// Helper function to return text result
18fn result_text(context: *mut sqlite3_context, text: &str) {
19    let cstr = CString::new(text).unwrap();
20    unsafe {
21        sqlite3ext_result_text(
22            context,
23            cstr.as_ptr(),
24            cstr.as_bytes().len() as i32,
25            Some(std::mem::transmute::<
26                i64,
27                unsafe extern "C" fn(*mut std::ffi::c_void),
28            >(-1i64)),
29        );
30    }
31}
32
33/// kg_version() - Returns the extension version
34pub fn kg_version(
35    context: *mut sqlite3_context,
36    _values: &[*mut sqlite3_value],
37) -> Result<(), Error> {
38    result_text(context, env!("CARGO_PKG_VERSION"));
39    Ok(())
40}
41
42/// kg_stats() - Returns graph statistics as JSON
43pub fn kg_stats(
44    context: *mut sqlite3_context,
45    _values: &[*mut sqlite3_value],
46) -> Result<(), Error> {
47    // For now, return a simple message indicating the extension is loaded
48    // Full implementation would require accessing the database connection
49    result_text(
50        context,
51        "{\"status\": \"Extension loaded - use KnowledgeGraph API for full stats\"}",
52    );
53    Ok(())
54}
55
56/// kg_pagerank() - Compute PageRank scores for all entities
57/// Parameters: damping (REAL, default 0.85), max_iterations (INTEGER, default 100), tolerance (REAL, default 1e-6)
58/// Returns JSON with algorithm info
59pub fn kg_pagerank(
60    context: *mut sqlite3_context,
61    values: &[*mut sqlite3_value],
62) -> Result<(), Error> {
63    // Parse optional damping parameter (default 0.85)
64    let damping = if !values.is_empty() {
65        unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[0]) }
66    } else {
67        0.85
68    };
69
70    // Parse optional max_iterations parameter (default 100)
71    let max_iterations = if values.len() >= 2 {
72        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as usize }
73    } else {
74        100
75    };
76
77    // Parse optional tolerance parameter (default 1e-6)
78    let tolerance = if values.len() >= 3 {
79        unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[2]) }
80    } else {
81        1e-6
82    };
83
84    // Return configuration info - actual computation requires database access
85    let result = format!(
86        "{{\"algorithm\": \"pagerank\", \"damping\": {}, \"max_iterations\": {}, \"tolerance\": {}, \"note\": \"Use KnowledgeGraph::kg_pagerank() for full computation\"}}",
87        damping, max_iterations, tolerance
88    );
89    result_text(context, &result);
90    Ok(())
91}
92
93/// kg_louvain() - Detect communities using Louvain algorithm
94/// Returns JSON with community memberships and modularity score
95pub fn kg_louvain(
96    context: *mut sqlite3_context,
97    _values: &[*mut sqlite3_value],
98) -> Result<(), Error> {
99    result_text(context, "{\"algorithm\": \"louvain\", \"note\": \"Use KnowledgeGraph::kg_louvain() for full computation\"}");
100    Ok(())
101}
102
103/// kg_bfs() - BFS traversal from a starting entity
104/// Parameters: start_id (INTEGER), max_depth (INTEGER, default 3)
105/// Returns JSON array of {entity_id, depth} objects
106pub fn kg_bfs(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<(), Error> {
107    if values.is_empty() {
108        return Err(Error::new_message(
109            "kg_bfs requires at least 1 argument: start_id",
110        ));
111    }
112
113    let start_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
114    let max_depth = if values.len() >= 2 {
115        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as u32 }
116    } else {
117        3
118    };
119
120    let result = format!(
121        "{{\"algorithm\": \"bfs\", \"start_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_bfs_traversal() for full computation\"}}",
122        start_id, max_depth
123    );
124    result_text(context, &result);
125    Ok(())
126}
127
128/// kg_shortest_path() - Find shortest path between two entities
129/// Parameters: from_id (INTEGER), to_id (INTEGER), max_depth (INTEGER, default 10)
130/// Returns JSON array of entity IDs representing the path
131pub fn kg_shortest_path(
132    context: *mut sqlite3_context,
133    values: &[*mut sqlite3_value],
134) -> Result<(), Error> {
135    if values.len() < 2 {
136        return Err(Error::new_message(
137            "kg_shortest_path requires at least 2 arguments: from_id, to_id",
138        ));
139    }
140
141    let from_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
142    let to_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[1]) };
143    let max_depth = if values.len() >= 3 {
144        unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[2]) as u32 }
145    } else {
146        10
147    };
148
149    let result = format!(
150        "{{\"algorithm\": \"shortest_path\", \"from_id\": {}, \"to_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_shortest_path() for full computation\"}}",
151        from_id, to_id, max_depth
152    );
153    result_text(context, &result);
154    Ok(())
155}
156
157/// kg_connected_components() - Find connected components in the graph
158/// Returns JSON with component information
159pub fn kg_connected_components(
160    context: *mut sqlite3_context,
161    _values: &[*mut sqlite3_value],
162) -> Result<(), Error> {
163    result_text(context, "{\"algorithm\": \"connected_components\", \"note\": \"Use KnowledgeGraph::kg_connected_components() for full computation\"}");
164    Ok(())
165}
166
167/// kg_bit_count() - Population count of a version validity bitstring
168/// Parameter: x (INTEGER, NULL treated as 0)
169/// Returns the number of set bits (how many versions a row belongs to)
170pub fn kg_bit_count(
171    context: *mut sqlite3_context,
172    values: &[*mut sqlite3_value],
173) -> Result<(), Error> {
174    if values.is_empty() {
175        return Err(Error::new_message("kg_bit_count requires 1 argument"));
176    }
177    // A SQL NULL reads back as 0 here, which has zero set bits — matching the
178    // rusqlite registration's NULL → 0 behavior in src/functions.rs.
179    let val = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
180    result_int64(context, val.count_ones() as i64);
181    Ok(())
182}
183
184/// Register functions
185fn register_extension_functions(db: *mut sqlite3) -> Result<(), Error> {
186    let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
187
188    // Basic info functions
189    define_scalar_function(db, "kg_version", 0, kg_version, flags)?;
190    define_scalar_function(db, "kg_stats", 0, kg_stats, flags)?;
191
192    // Version bitstring helper (QuaQue versioning)
193    define_scalar_function(db, "kg_bit_count", 1, kg_bit_count, flags)?;
194
195    // Graph algorithm functions with optional parameters
196    define_scalar_function(db, "kg_pagerank", 0, kg_pagerank, flags)?;
197    define_scalar_function(db, "kg_pagerank", 1, kg_pagerank, flags)?;
198    define_scalar_function(db, "kg_pagerank", 2, kg_pagerank, flags)?;
199    define_scalar_function(db, "kg_pagerank", 3, kg_pagerank, flags)?;
200
201    define_scalar_function(db, "kg_louvain", 0, kg_louvain, flags)?;
202
203    define_scalar_function(db, "kg_bfs", 1, kg_bfs, flags)?;
204    define_scalar_function(db, "kg_bfs", 2, kg_bfs, flags)?;
205
206    define_scalar_function(db, "kg_shortest_path", 2, kg_shortest_path, flags)?;
207    define_scalar_function(db, "kg_shortest_path", 3, kg_shortest_path, flags)?;
208
209    define_scalar_function(
210        db,
211        "kg_connected_components",
212        0,
213        kg_connected_components,
214        flags,
215    )?;
216
217    Ok(())
218}
219
220/// Extension entry point
221#[sqlite_entrypoint]
222pub fn sqlite3_sqlite_knowledge_graph_init(db: *mut sqlite3) -> Result<(), Error> {
223    register_extension_functions(db)
224}
225
#[cfg(test)]
mod tests {
    /// The crate version must have a numeric `major.minor.patch` core.
    #[test]
    fn test_kg_version_format() {
        let version = env!("CARGO_PKG_VERSION");

        // Ignore any pre-release (`-…`) or build-metadata (`+…`) suffix.
        let core = version
            .split(|c: char| c == '-' || c == '+')
            .next()
            .unwrap_or(version);

        let parts: Vec<&str> = core.split('.').collect();
        assert_eq!(parts.len(), 3, "expected x.y.z, got {:?}", version);
        for part in parts {
            assert!(
                part.parse::<u64>().is_ok(),
                "non-numeric version component {:?} in {:?}",
                part,
                version
            );
        }
    }
}