Skip to main content

sqlite_graphrag/storage/
chunks.rs

// src/storage/chunks.rs
// Chunk storage for memory bodies that exceed the 512-token input limit of the E5 embedding model.
3
4use crate::embedder::f32_to_bytes;
5use crate::errors::AppError;
6use rusqlite::{params, Connection};
7
/// One slice of a memory body, stored as a row of the `memory_chunks` table.
///
/// Bodies too long for the embedder's token limit are split into several
/// chunks, and each chunk is embedded separately (see `upsert_chunk_vec`).
///
/// `PartialEq`/`Eq` are derived so whole chunks can be compared directly,
/// e.g. with `assert_eq!` against the result of `get_chunks_by_memory`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// Id of the owning row in the `memories` table.
    pub memory_id: i64,
    /// Position of this chunk within its memory; chunks are read back ordered
    /// by this value. (This module's tests use 0-based indices.)
    pub chunk_idx: i32,
    /// The text of this chunk.
    pub chunk_text: String,
    /// Start offset of the chunk within the memory body. The unit (bytes vs.
    /// chars) is decided by the chunker — TODO confirm.
    pub start_offset: i32,
    /// End offset of the chunk within the memory body, in the same unit as
    /// `start_offset`; exclusive, judging by this module's test data.
    pub end_offset: i32,
    /// Number of tokens in `chunk_text`, as counted by the caller.
    pub token_count: i32,
}
17
18pub fn insert_chunks(conn: &Connection, chunks: &[Chunk]) -> Result<(), AppError> {
19    for chunk in chunks {
20        conn.execute(
21            "INSERT INTO memory_chunks (memory_id, chunk_idx, chunk_text, start_offset, end_offset, token_count)
22             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
23            params![
24                chunk.memory_id,
25                chunk.chunk_idx,
26                chunk.chunk_text,
27                chunk.start_offset,
28                chunk.end_offset,
29                chunk.token_count,
30            ],
31        )?;
32    }
33    Ok(())
34}
35
36pub fn upsert_chunk_vec(
37    conn: &Connection,
38    _rowid: i64,
39    memory_id: i64,
40    chunk_idx: i32,
41    embedding: &[f32],
42) -> Result<(), AppError> {
43    conn.execute(
44        "INSERT OR REPLACE INTO vec_chunks(rowid, memory_id, chunk_idx, embedding)
45         VALUES (
46             (SELECT id FROM memory_chunks WHERE memory_id = ?1 AND chunk_idx = ?2),
47             ?1, ?2, ?3
48         )",
49        params![memory_id, chunk_idx, f32_to_bytes(embedding)],
50    )?;
51    Ok(())
52}
53
54pub fn delete_chunks(conn: &Connection, memory_id: i64) -> Result<(), AppError> {
55    conn.execute(
56        "DELETE FROM memory_chunks WHERE memory_id = ?1",
57        params![memory_id],
58    )?;
59    Ok(())
60}
61
62pub fn knn_search_chunks(
63    conn: &Connection,
64    embedding: &[f32],
65    k: usize,
66) -> Result<Vec<(i64, i32, f32)>, AppError> {
67    let bytes = f32_to_bytes(embedding);
68    let mut stmt = conn.prepare(
69        "SELECT memory_id, chunk_idx, distance FROM vec_chunks
70         WHERE embedding MATCH ?1
71         ORDER BY distance LIMIT ?2",
72    )?;
73    let rows = stmt
74        .query_map(params![bytes, k as i64], |r| {
75            Ok((
76                r.get::<_, i64>(0)?,
77                r.get::<_, i32>(1)?,
78                r.get::<_, f32>(2)?,
79            ))
80        })?
81        .collect::<Result<Vec<_>, _>>()?;
82    Ok(rows)
83}
84
85pub fn get_chunks_by_memory(conn: &Connection, memory_id: i64) -> Result<Vec<Chunk>, AppError> {
86    let mut stmt = conn.prepare(
87        "SELECT memory_id, chunk_idx, chunk_text, start_offset, end_offset, token_count
88         FROM memory_chunks WHERE memory_id = ?1 ORDER BY chunk_idx",
89    )?;
90    let rows = stmt
91        .query_map(params![memory_id], |r| {
92            Ok(Chunk {
93                memory_id: r.get(0)?,
94                chunk_idx: r.get(1)?,
95                chunk_text: r.get(2)?,
96                start_offset: r.get(3)?,
97                end_offset: r.get(4)?,
98                token_count: r.get(5)?,
99            })
100        })?
101        .collect::<Result<Vec<_>, _>>()?;
102    Ok(rows)
103}
104
#[cfg(test)]
mod tests {
    // Each test runs against a fresh on-disk database with every migration
    // applied and the vector extension registered.

    use super::*;
    use crate::constants::EMBEDDING_DIM;
    use crate::storage::connection::register_vec_extension;
    use rusqlite::Connection;
    use tempfile::TempDir;

    /// Opens a migrated database inside a temporary directory.
    ///
    /// The `TempDir` is returned so the directory is not removed while the
    /// test still uses the database.
    fn setup_db() -> (TempDir, Connection) {
        register_vec_extension();
        let dir = TempDir::new().unwrap();
        let mut conn = Connection::open(dir.path().join("test.db")).unwrap();
        crate::migrations::runner().run(&mut conn).unwrap();
        (dir, conn)
    }

    /// Inserts a single parent memory and returns its rowid.
    fn insert_memory(conn: &Connection) -> i64 {
        let sql = "INSERT INTO memories (namespace, name, type, description, body, body_hash)
             VALUES ('global', 'test-mem', 'user', 'desc', 'body', 'hash1')";
        conn.execute(sql, []).unwrap();
        conn.last_insert_rowid()
    }

    /// Builds a `Chunk` from its field values.
    fn make_chunk(
        memory_id: i64,
        chunk_idx: i32,
        text: &str,
        start_offset: i32,
        end_offset: i32,
        token_count: i32,
    ) -> Chunk {
        Chunk {
            memory_id,
            chunk_idx,
            chunk_text: text.to_string(),
            start_offset,
            end_offset,
            token_count,
        }
    }

    #[test]
    fn test_insert_chunks_vazia_ok() {
        let (_tmp, conn) = setup_db();
        assert!(insert_chunks(&conn, &[]).is_ok());
    }

    #[test]
    fn test_insert_chunks_e_get_por_memory() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);

        let chunks = [
            make_chunk(memory_id, 0, "primeiro chunk", 0, 14, 3),
            make_chunk(memory_id, 1, "segundo chunk", 15, 28, 3),
        ];
        insert_chunks(&conn, &chunks).unwrap();

        let fetched = get_chunks_by_memory(&conn, memory_id).unwrap();
        assert_eq!(fetched.len(), 2);

        let first = &fetched[0];
        assert_eq!(first.chunk_idx, 0);
        assert_eq!(first.chunk_text, "primeiro chunk");
        assert_eq!(first.start_offset, 0);
        assert_eq!(first.end_offset, 14);
        assert_eq!(first.token_count, 3);

        let second = &fetched[1];
        assert_eq!(second.chunk_idx, 1);
        assert_eq!(second.chunk_text, "segundo chunk");
    }

    #[test]
    fn test_get_chunks_memory_inexistente_retorna_vazio() {
        let (_tmp, conn) = setup_db();
        assert!(get_chunks_by_memory(&conn, 9999).unwrap().is_empty());
    }

    #[test]
    fn test_delete_chunks_remove_todos() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);

        let chunks = [
            make_chunk(memory_id, 0, "chunk a", 0, 7, 2),
            make_chunk(memory_id, 1, "chunk b", 8, 15, 2),
        ];
        insert_chunks(&conn, &chunks).unwrap();

        delete_chunks(&conn, memory_id).unwrap();

        assert!(get_chunks_by_memory(&conn, memory_id).unwrap().is_empty());
    }

    #[test]
    fn test_delete_chunks_memory_sem_chunks_ok() {
        let (_tmp, conn) = setup_db();
        assert!(delete_chunks(&conn, 9999).is_ok());
    }

    #[test]
    fn test_get_chunks_ordenados_por_chunk_idx() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);

        // Inserted deliberately out of order.
        let chunks = [
            make_chunk(memory_id, 2, "terceiro", 20, 28, 1),
            make_chunk(memory_id, 0, "primeiro", 0, 8, 1),
            make_chunk(memory_id, 1, "segundo", 9, 16, 1),
        ];
        insert_chunks(&conn, &chunks).unwrap();

        let order: Vec<i32> = get_chunks_by_memory(&conn, memory_id)
            .unwrap()
            .iter()
            .map(|c| c.chunk_idx)
            .collect();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_upsert_chunk_vec_e_knn_search() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);

        insert_chunks(
            &conn,
            &[make_chunk(memory_id, 0, "embedding test", 0, 14, 2)],
        )
        .unwrap();

        // Unit vector along the first axis.
        let mut embedding = vec![0.0f32; EMBEDDING_DIM];
        embedding[0] = 1.0;

        let chunk_id: i64 = conn
            .query_row(
                "SELECT id FROM memory_chunks WHERE memory_id = ?1 AND chunk_idx = 0",
                params![memory_id],
                |row| row.get(0),
            )
            .unwrap();

        upsert_chunk_vec(&conn, chunk_id, memory_id, 0, &embedding).unwrap();

        let hits = knn_search_chunks(&conn, &embedding, 1).unwrap();
        assert_eq!(hits.len(), 1);
        let (hit_memory, hit_idx, _distance) = hits[0];
        assert_eq!(hit_memory, memory_id);
        assert_eq!(hit_idx, 0);
    }

    #[test]
    fn test_knn_search_chunks_sem_dados_retorna_vazio() {
        let (_tmp, conn) = setup_db();
        let zeros = vec![0.0f32; EMBEDDING_DIM];
        assert!(knn_search_chunks(&conn, &zeros, 5).unwrap().is_empty());
    }

    #[test]
    fn test_insert_chunks_fk_invalida_falha() {
        let (_tmp, conn) = setup_db();
        // No memory with this id exists, so the insert must be rejected.
        let orphan = make_chunk(99999, 0, "sem pai", 0, 7, 1);
        assert!(insert_chunks(&conn, &[orphan]).is_err());
    }
}