Skip to main content

semantic_memory/
db.rs

1//! Database initialization, migrations, and connection management.
2
3use crate::config::EmbeddingConfig;
4use crate::error::MemoryError;
5use rusqlite::{params, Connection};
6use std::path::Path;
7
/// Schema version 1: the initial, complete schema.
///
/// Creates the conversation tables (`sessions`, `messages`), the knowledge
/// tables (`facts`, `facts_rowid_map`, `facts_fts`), the document tables
/// (`documents`, `chunks`, `chunks_rowid_map`, `chunks_fts`) and the
/// `embedding_metadata` table, whose `CHECK (id = 1)` limits it to one row.
///
/// The `*_fts` tables are FTS5 virtual tables declared with `content=''`;
/// the `*_rowid_map` tables pair an autoincrement integer rowid with the
/// text primary key of the owning row.
///
/// This SQL is part of the applied-migration history: do not edit it, add a
/// new migration instead.
const MIGRATION_V1: &str = r#"
-- CONVERSATIONS
CREATE TABLE sessions (
    id          TEXT PRIMARY KEY,
    channel     TEXT NOT NULL DEFAULT 'repl',
    created_at  TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at  TEXT NOT NULL DEFAULT (datetime('now')),
    metadata    TEXT
);

CREATE INDEX idx_sessions_updated ON sessions(updated_at DESC);

CREATE TABLE messages (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id  TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    role        TEXT NOT NULL CHECK (role IN ('system', 'user', 'assistant', 'tool')),
    content     TEXT NOT NULL,
    token_count INTEGER,
    created_at  TEXT NOT NULL DEFAULT (datetime('now')),
    metadata    TEXT
);

CREATE INDEX idx_messages_session ON messages(session_id, created_at ASC);
CREATE INDEX idx_messages_created ON messages(created_at DESC);

-- KNOWLEDGE (Facts)
CREATE TABLE facts (
    id          TEXT PRIMARY KEY,
    namespace   TEXT NOT NULL DEFAULT 'general',
    content     TEXT NOT NULL,
    source      TEXT,
    embedding   BLOB,
    created_at  TEXT NOT NULL DEFAULT (datetime('now')),
    updated_at  TEXT NOT NULL DEFAULT (datetime('now')),
    metadata    TEXT
);

CREATE INDEX idx_facts_namespace ON facts(namespace);
CREATE INDEX idx_facts_updated ON facts(updated_at DESC);

CREATE TABLE facts_rowid_map (
    rowid       INTEGER PRIMARY KEY AUTOINCREMENT,
    fact_id     TEXT NOT NULL UNIQUE REFERENCES facts(id) ON DELETE CASCADE
);

CREATE VIRTUAL TABLE facts_fts USING fts5(
    content,
    content='',
    content_rowid='rowid',
    tokenize='porter unicode61'
);

-- DOCUMENTS (Chunked content)
CREATE TABLE documents (
    id          TEXT PRIMARY KEY,
    title       TEXT NOT NULL,
    source_path TEXT,
    namespace   TEXT NOT NULL DEFAULT 'general',
    created_at  TEXT NOT NULL DEFAULT (datetime('now')),
    metadata    TEXT
);

CREATE TABLE chunks (
    id          TEXT PRIMARY KEY,
    document_id TEXT NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    chunk_index INTEGER NOT NULL,
    content     TEXT NOT NULL,
    token_count INTEGER,
    embedding   BLOB,
    created_at  TEXT NOT NULL DEFAULT (datetime('now'))
);

CREATE INDEX idx_chunks_document ON chunks(document_id, chunk_index ASC);

CREATE TABLE chunks_rowid_map (
    rowid       INTEGER PRIMARY KEY AUTOINCREMENT,
    chunk_id    TEXT NOT NULL UNIQUE REFERENCES chunks(id) ON DELETE CASCADE
);

CREATE VIRTUAL TABLE chunks_fts USING fts5(
    content,
    content='',
    content_rowid='rowid',
    tokenize='porter unicode61'
);

-- EMBEDDING METADATA
CREATE TABLE embedding_metadata (
    id          INTEGER PRIMARY KEY CHECK (id = 1),
    model_name  TEXT NOT NULL,
    dimensions  INTEGER NOT NULL,
    updated_at  TEXT NOT NULL DEFAULT (datetime('now'))
);
"#;
103
/// Schema version 2: makes conversation messages searchable.
///
/// Adds an `embedding` BLOB column to `messages`, plus a `messages_rowid_map`
/// table and a `messages_fts` FTS5 virtual table (declared with `content=''`).
///
/// Part of the applied-migration history: do not edit, add a new migration.
const MIGRATION_V2: &str = r#"
-- V2: Message embeddings for conversation search
ALTER TABLE messages ADD COLUMN embedding BLOB;

CREATE TABLE messages_rowid_map (
    rowid       INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id  INTEGER NOT NULL UNIQUE REFERENCES messages(id) ON DELETE CASCADE
);

CREATE VIRTUAL TABLE messages_fts USING fts5(
    content,
    content='',
    content_rowid='rowid',
    tokenize='porter unicode61'
);
"#;
121
/// Schema version 3: tracks whether stored embeddings are stale.
///
/// Adds the `embeddings_dirty` integer flag (default `0`) to
/// `embedding_metadata`.
///
/// Part of the applied-migration history: do not edit, add a new migration.
const MIGRATION_V3: &str = r#"
-- V3: Embedding staleness tracking
ALTER TABLE embedding_metadata ADD COLUMN embeddings_dirty INTEGER NOT NULL DEFAULT 0;
"#;
127
/// Schema version 4: key/value storage for HNSW index metadata.
///
/// Creates `hnsw_metadata` (text `key` primary key, non-null text `value`)
/// if it does not already exist.
///
/// Part of the applied-migration history: do not edit, add a new migration.
const MIGRATION_V4: &str = r#"
-- V4: HNSW index metadata
CREATE TABLE IF NOT EXISTS hnsw_metadata (
    key TEXT PRIMARY KEY,
    value TEXT NOT NULL
);
"#;
136
/// Schema version 5: quantized embeddings and a persisted HNSW key map.
///
/// Adds an `embedding_q8` BLOB column to `facts`, `chunks` and `messages`,
/// and creates `hnsw_keymap` (integer `node_id` primary key, unique text
/// `item_key`, integer `deleted` flag defaulting to `0`) together with an
/// index on `item_key`.
///
/// Part of the applied-migration history: do not edit, add a new migration.
const MIGRATION_V5: &str = r#"
-- V5: Quantized embeddings + HNSW keymap persistence
ALTER TABLE facts ADD COLUMN embedding_q8 BLOB;
ALTER TABLE chunks ADD COLUMN embedding_q8 BLOB;
ALTER TABLE messages ADD COLUMN embedding_q8 BLOB;

CREATE TABLE IF NOT EXISTS hnsw_keymap (
    node_id     INTEGER PRIMARY KEY,
    item_key    TEXT NOT NULL UNIQUE,
    deleted     INTEGER NOT NULL DEFAULT 0
);

CREATE INDEX idx_hnsw_keymap_key ON hnsw_keymap(item_key);
"#;
152
153/// Run a closure inside an `unchecked_transaction`, committing on success.
154///
155/// SAFETY: We hold &Connection (not &mut) via Mutex::lock(). unchecked_transaction()
156/// is required because transaction() needs &mut self. The Mutex serializes all access,
157/// preventing concurrent transaction nesting.
158pub fn with_transaction<F, T>(conn: &Connection, f: F) -> Result<T, MemoryError>
159where
160    F: FnOnce(&rusqlite::Transaction<'_>) -> Result<T, MemoryError>,
161{
162    let tx = conn.unchecked_transaction()?;
163    let result = f(&tx)?;
164    tx.commit()?;
165    Ok(result)
166}
167
168/// Open or create a SQLite database at `path`, configure pragmas, and run migrations.
169pub fn open_database(path: &Path) -> Result<Connection, MemoryError> {
170    // Create parent directories if needed
171    if let Some(parent) = path.parent() {
172        if !parent.as_os_str().is_empty() {
173            std::fs::create_dir_all(parent).map_err(|e| {
174                MemoryError::Other(format!(
175                    "Failed to create database directory {}: {}",
176                    parent.display(),
177                    e
178                ))
179            })?;
180        }
181    }
182
183    let conn = Connection::open(path)?;
184
185    // Set pragmas BEFORE any other operation
186    conn.execute_batch(
187        "PRAGMA journal_mode = WAL;
188         PRAGMA foreign_keys = ON;
189         PRAGMA busy_timeout = 5000;
190         PRAGMA synchronous = NORMAL;",
191    )?;
192
193    run_migrations(&conn)?;
194
195    Ok(conn)
196}
197
198/// Ordered list of migrations. Each entry is (version, SQL).
199const MIGRATIONS: &[(u32, &str)] = &[
200    (1, MIGRATION_V1),
201    (2, MIGRATION_V2),
202    (3, MIGRATION_V3),
203    (4, MIGRATION_V4),
204    (5, MIGRATION_V5),
205];
206
207/// Run all pending migrations.
208pub fn run_migrations(conn: &Connection) -> Result<(), MemoryError> {
209    // Create version table if it doesn't exist
210    conn.execute_batch(
211        "CREATE TABLE IF NOT EXISTS _schema_version (
212            version     INTEGER PRIMARY KEY,
213            applied_at  TEXT NOT NULL DEFAULT (datetime('now'))
214        );",
215    )?;
216
217    for &(version, sql) in MIGRATIONS {
218        let current_version: u32 = conn
219            .query_row(
220                "SELECT COALESCE(MAX(version), 0) FROM _schema_version",
221                [],
222                |row| row.get(0),
223            )
224            .unwrap_or(0);
225
226        if current_version >= version {
227            continue;
228        }
229
230        with_transaction(conn, |tx| {
231            tx.execute_batch(sql)
232                .map_err(|e| MemoryError::MigrationFailed {
233                    version,
234                    reason: e.to_string(),
235                })?;
236
237            tx.execute(
238                "INSERT INTO _schema_version (version) VALUES (?1)",
239                params![version],
240            )
241            .map_err(|e| MemoryError::MigrationFailed {
242                version,
243                reason: e.to_string(),
244            })?;
245
246            Ok(())
247        })?;
248
249        tracing::info!("Applied migration V{}", version);
250    }
251
252    Ok(())
253}
254
255/// Check and update the embedding_metadata singleton row.
256///
257/// If the row exists and model/dimensions don't match, warn and update.
258/// If no row exists, insert one. If it matches, no-op.
259pub fn check_embedding_metadata(
260    conn: &Connection,
261    config: &EmbeddingConfig,
262) -> Result<(), MemoryError> {
263    let existing: Option<(String, usize)> = conn
264        .query_row(
265            "SELECT model_name, dimensions FROM embedding_metadata WHERE id = 1",
266            [],
267            |row| Ok((row.get(0)?, row.get(1)?)),
268        )
269        .ok();
270
271    match existing {
272        Some((model, dims)) => {
273            if model != config.model || dims != config.dimensions {
274                tracing::warn!(
275                    stored_model = %model,
276                    stored_dims = dims,
277                    configured_model = %config.model,
278                    configured_dims = config.dimensions,
279                    "Embedding model changed. Existing embeddings are invalid. Call reembed_all() to re-embed."
280                );
281                conn.execute(
282                    "UPDATE embedding_metadata SET model_name = ?1, dimensions = ?2, \
283                     embeddings_dirty = 1, updated_at = datetime('now') WHERE id = 1",
284                    params![config.model, config.dimensions],
285                )?;
286            }
287        }
288        None => {
289            conn.execute(
290                "INSERT INTO embedding_metadata (id, model_name, dimensions) VALUES (1, ?1, ?2)",
291                params![config.model, config.dimensions],
292            )?;
293        }
294    }
295
296    Ok(())
297}
298
/// Serialize an embedding into the byte layout stored in SQLite BLOB columns.
///
/// Each `f32` is written as its 4 little-endian bytes, in input order, so
/// the result is exactly `embedding.len() * 4` bytes long.
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(embedding.len() * 4);
    blob.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    blob
}
307
308/// Check if embeddings are stale after a model change.
309pub fn is_embeddings_dirty(conn: &Connection) -> Result<bool, MemoryError> {
310    let dirty: i32 = conn
311        .query_row(
312            "SELECT COALESCE(embeddings_dirty, 0) FROM embedding_metadata WHERE id = 1",
313            [],
314            |row| row.get(0),
315        )
316        .unwrap_or(0);
317    Ok(dirty != 0)
318}
319
320/// Clear the dirty flag after re-embedding.
321pub fn clear_embeddings_dirty(conn: &Connection) -> Result<(), MemoryError> {
322    conn.execute(
323        "UPDATE embedding_metadata SET embeddings_dirty = 0 WHERE id = 1",
324        [],
325    )?;
326    Ok(())
327}
328
329/// Decode a SQLite BLOB back to f32 values.
330///
331/// Returns an error if the byte length is not divisible by 4.
332/// Uses `bytemuck::try_cast_slice` for zero-copy decoding when alignment permits,
333/// falling back to manual decode otherwise.
334#[allow(clippy::manual_is_multiple_of)] // MSRV 1.75: is_multiple_of stabilized later
335pub fn bytes_to_embedding(bytes: &[u8]) -> Result<Vec<f32>, MemoryError> {
336    if bytes.len() % 4 != 0 {
337        return Err(MemoryError::InvalidEmbedding {
338            expected_bytes: bytes.len() - (bytes.len() % 4),
339            actual_bytes: bytes.len(),
340        });
341    }
342    // Heap-allocated Vec<u8> from SQLite is aligned, so cast_slice succeeds.
343    // If alignment is off (shouldn't happen), fall back to manual decode.
344    match bytemuck::try_cast_slice::<u8, f32>(bytes) {
345        Ok(slice) => Ok(slice.to_vec()),
346        Err(_) => {
347            let mut embedding = Vec::with_capacity(bytes.len() / 4);
348            for chunk in bytes.chunks_exact(4) {
349                embedding.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
350            }
351            Ok(embedding)
352        }
353    }
354}