1use sqlite_loadable::{
6 define_scalar_function, ext::sqlite3ext_result_text, prelude::*, Error, FunctionFlags,
7};
8use std::ffi::CString;
9
10fn result_text(context: *mut sqlite3_context, text: &str) {
12 let cstr = CString::new(text).unwrap();
13 unsafe {
14 sqlite3ext_result_text(
15 context,
16 cstr.as_ptr(),
17 cstr.as_bytes().len() as i32,
18 Some(std::mem::transmute::<
19 i64,
20 unsafe extern "C" fn(*mut std::ffi::c_void),
21 >(-1i64)),
22 );
23 }
24}
25
26pub fn kg_version(
28 context: *mut sqlite3_context,
29 _values: &[*mut sqlite3_value],
30) -> Result<(), Error> {
31 result_text(context, env!("CARGO_PKG_VERSION"));
32 Ok(())
33}
34
35pub fn kg_stats(
37 context: *mut sqlite3_context,
38 _values: &[*mut sqlite3_value],
39) -> Result<(), Error> {
40 result_text(
43 context,
44 "{\"status\": \"Extension loaded - use KnowledgeGraph API for full stats\"}",
45 );
46 Ok(())
47}
48
49pub fn kg_pagerank(
53 context: *mut sqlite3_context,
54 values: &[*mut sqlite3_value],
55) -> Result<(), Error> {
56 let damping = if !values.is_empty() {
58 unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[0]) }
59 } else {
60 0.85
61 };
62
63 let max_iterations = if values.len() >= 2 {
65 unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as usize }
66 } else {
67 100
68 };
69
70 let tolerance = if values.len() >= 3 {
72 unsafe { sqlite_loadable::ext::sqlite3ext_value_double(values[2]) }
73 } else {
74 1e-6
75 };
76
77 let result = format!(
79 "{{\"algorithm\": \"pagerank\", \"damping\": {}, \"max_iterations\": {}, \"tolerance\": {}, \"note\": \"Use KnowledgeGraph::kg_pagerank() for full computation\"}}",
80 damping, max_iterations, tolerance
81 );
82 result_text(context, &result);
83 Ok(())
84}
85
86pub fn kg_louvain(
89 context: *mut sqlite3_context,
90 _values: &[*mut sqlite3_value],
91) -> Result<(), Error> {
92 result_text(context, "{\"algorithm\": \"louvain\", \"note\": \"Use KnowledgeGraph::kg_louvain() for full computation\"}");
93 Ok(())
94}
95
96pub fn kg_bfs(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<(), Error> {
100 if values.is_empty() {
101 return Err(Error::new_message(
102 "kg_bfs requires at least 1 argument: start_id",
103 ));
104 }
105
106 let start_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
107 let max_depth = if values.len() >= 2 {
108 unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[1]) as u32 }
109 } else {
110 3
111 };
112
113 let result = format!(
114 "{{\"algorithm\": \"bfs\", \"start_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_bfs_traversal() for full computation\"}}",
115 start_id, max_depth
116 );
117 result_text(context, &result);
118 Ok(())
119}
120
121pub fn kg_shortest_path(
125 context: *mut sqlite3_context,
126 values: &[*mut sqlite3_value],
127) -> Result<(), Error> {
128 if values.len() < 2 {
129 return Err(Error::new_message(
130 "kg_shortest_path requires at least 2 arguments: from_id, to_id",
131 ));
132 }
133
134 let from_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[0]) };
135 let to_id = unsafe { sqlite_loadable::ext::sqlite3ext_value_int64(values[1]) };
136 let max_depth = if values.len() >= 3 {
137 unsafe { sqlite_loadable::ext::sqlite3ext_value_int(values[2]) as u32 }
138 } else {
139 10
140 };
141
142 let result = format!(
143 "{{\"algorithm\": \"shortest_path\", \"from_id\": {}, \"to_id\": {}, \"max_depth\": {}, \"note\": \"Use KnowledgeGraph::kg_shortest_path() for full computation\"}}",
144 from_id, to_id, max_depth
145 );
146 result_text(context, &result);
147 Ok(())
148}
149
150pub fn kg_connected_components(
153 context: *mut sqlite3_context,
154 _values: &[*mut sqlite3_value],
155) -> Result<(), Error> {
156 result_text(context, "{\"algorithm\": \"connected_components\", \"note\": \"Use KnowledgeGraph::kg_connected_components() for full computation\"}");
157 Ok(())
158}
159
160fn register_extension_functions(db: *mut sqlite3) -> Result<(), Error> {
162 let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
163
164 define_scalar_function(db, "kg_version", 0, kg_version, flags)?;
166 define_scalar_function(db, "kg_stats", 0, kg_stats, flags)?;
167
168 define_scalar_function(db, "kg_pagerank", 0, kg_pagerank, flags)?;
170 define_scalar_function(db, "kg_pagerank", 1, kg_pagerank, flags)?;
171 define_scalar_function(db, "kg_pagerank", 2, kg_pagerank, flags)?;
172 define_scalar_function(db, "kg_pagerank", 3, kg_pagerank, flags)?;
173
174 define_scalar_function(db, "kg_louvain", 0, kg_louvain, flags)?;
175
176 define_scalar_function(db, "kg_bfs", 1, kg_bfs, flags)?;
177 define_scalar_function(db, "kg_bfs", 2, kg_bfs, flags)?;
178
179 define_scalar_function(db, "kg_shortest_path", 2, kg_shortest_path, flags)?;
180 define_scalar_function(db, "kg_shortest_path", 3, kg_shortest_path, flags)?;
181
182 define_scalar_function(
183 db,
184 "kg_connected_components",
185 0,
186 kg_connected_components,
187 flags,
188 )?;
189
190 Ok(())
191}
192
193#[sqlite_entrypoint]
195pub fn sqlite3_sqlite_knowledge_graph_init(db: *mut sqlite3) -> Result<(), Error> {
196 register_extension_functions(db)
197}
198
#[cfg(test)]
mod tests {
    /// Sanity-checks the compile-time package version string that
    /// `kg_version` reports: it must be non-empty and contain a dot.
    #[test]
    fn test_kg_version_format() {
        let pkg_version: &str = env!("CARGO_PKG_VERSION");
        assert!(!pkg_version.is_empty());
        assert!(pkg_version.find('.').is_some());
    }
}