Skip to main content

sqlite_knowledge_graph/
functions.rs

//! SQLite custom functions for knowledge graph operations.
//!
//! Note: the rusqlite function-context API does not give scalar functions (safe) access to
//! the database connection, so these SQL functions can only compute on their arguments and
//! serve as a convenience layer. For full functionality, use the `KnowledgeGraph` Rust API
//! directly.
6
7use crate::error::Result;
8
9/// Register all knowledge graph custom functions with SQLite.
10pub fn register_functions(conn: &rusqlite::Connection) -> Result<()> {
11    // Register kg_cosine_similarity - utility function that can be called from SQL
12    conn.create_scalar_function(
13        "kg_cosine_similarity",
14        2,
15        rusqlite::functions::FunctionFlags::SQLITE_UTF8,
16        |ctx| {
17            let vec1_blob: Vec<u8> = ctx.get(0)?;
18            let vec2_blob: Vec<u8> = ctx.get(1)?;
19
20            // Deserialize vectors from blobs
21            let mut vec1 = Vec::new();
22            for chunk in vec1_blob.chunks_exact(4) {
23                let bytes: [u8; 4] = match chunk.try_into() {
24                    Ok(b) => b,
25                    Err(_) => return Ok(0.0f64),
26                };
27                vec1.push(f32::from_le_bytes(bytes));
28            }
29
30            let mut vec2 = Vec::new();
31            for chunk in vec2_blob.chunks_exact(4) {
32                let bytes: [u8; 4] = match chunk.try_into() {
33                    Ok(b) => b,
34                    Err(_) => return Ok(0.0f64),
35                };
36                vec2.push(f32::from_le_bytes(bytes));
37            }
38
39            if vec1.len() != vec2.len() {
40                return Ok(0.0f64);
41            }
42
43            let mut dot_product = 0.0_f32;
44            let mut norm_a = 0.0_f32;
45            let mut norm_b = 0.0_f32;
46
47            for i in 0..vec1.len() {
48                dot_product += vec1[i] * vec2[i];
49                norm_a += vec1[i] * vec1[i];
50                norm_b += vec2[i] * vec2[i];
51            }
52
53            if norm_a == 0.0 || norm_b == 0.0 {
54                return Ok(0.0f64);
55            }
56
57            let similarity = dot_product / (norm_a.sqrt() * norm_b.sqrt());
58            Ok(similarity as f64)
59        },
60    )?;
61
62    Ok(())
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::{params, Connection};

    /// Pack `f32` components into the little-endian blob layout `kg_cosine_similarity` reads.
    fn encode(components: &[f32]) -> Vec<u8> {
        components.iter().flat_map(|c| c.to_le_bytes()).collect()
    }

    /// Open an in-memory database with the schema and custom functions installed.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        register_functions(&conn).expect("registering custom functions should succeed");
        conn
    }

    /// Evaluate `kg_cosine_similarity` on two vectors through SQL.
    fn similarity(conn: &Connection, a: &[f32], b: &[f32]) -> f64 {
        conn.query_row(
            "SELECT kg_cosine_similarity(?1, ?2)",
            params![encode(a), encode(b)],
            |row| row.get(0),
        )
        .unwrap()
    }

    #[test]
    fn test_register_functions() {
        let conn = setup();

        // Identical vectors are perfectly similar.
        let sim = similarity(&conn, &[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal_and_opposite() {
        let conn = setup();

        assert!(similarity(&conn, &[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
        assert!((similarity(&conn, &[1.0, 2.0], &[-1.0, -2.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs_return_zero() {
        let conn = setup();

        // Vectors of different dimensions are not comparable.
        assert_eq!(similarity(&conn, &[1.0, 0.0], &[1.0, 0.0, 0.0]), 0.0);
        // A zero-magnitude vector has no direction.
        assert_eq!(similarity(&conn, &[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
}