1use anyhow::Result;
5use rusqlite::Connection;
6use std::path::Path;
7
/// Approximate token budget for a single chunk. `chunk_content` estimates
/// tokens as bytes / 4 and closes the current chunk once adding the next
/// paragraph would push the estimate past this value.
const CHUNK_TOKEN_TARGET: usize = 400;
/// Overlap setting applied when a chunk is closed: `chunk_content` carries the
/// last `CHUNK_OVERLAP / 4` whitespace-separated words into the next chunk.
const CHUNK_OVERLAP: usize = 80;
10
11#[cfg(feature = "memory_vector")]
12use std::sync::Once;
13
/// One-shot guard consulted by `ensure_vec_extension_loaded`; `Once` makes sure
/// the registration closure runs a single time.
#[cfg(feature = "memory_vector")]
static VEC_INIT: Once = Once::new();
16
/// Registers the sqlite-vec entry point as an SQLite auto-extension.
///
/// The registration is guarded by `VEC_INIT`, so calling this function
/// repeatedly is cheap and only the first call does any work.
#[cfg(feature = "memory_vector")]
pub fn ensure_vec_extension_loaded() {
    VEC_INIT.call_once(|| {
        // Erase the concrete function type first; the transmute below turns the
        // raw pointer into whatever callback type `sqlite3_auto_extension` takes.
        let entry_point = sqlite_vec::sqlite3_vec_init as *const ();

        // SAFETY: relies on `sqlite3_vec_init` having the extension entry-point
        // signature that `sqlite3_auto_extension` expects — NOTE(review):
        // confirm against the sqlite-vec crate version in use.
        unsafe {
            rusqlite::ffi::sqlite3_auto_extension(Some(std::mem::transmute(entry_point)));
        }
    });
}
27
/// Returns the location of an agent's index database:
/// `<workspace_root>/memory/<agent_id>.sqlite`.
///
/// The path is only computed; nothing is created on disk.
pub fn index_path(workspace_root: &Path, agent_id: &str) -> std::path::PathBuf {
    let mut db_path = workspace_root.join("memory");
    db_path.push(format!("{}.sqlite", agent_id));
    db_path
}
34
35pub fn ensure_index(conn: &Connection) -> Result<()> {
37 conn.execute_batch(
38 r#"
39 CREATE VIRTUAL TABLE IF NOT EXISTS memory_fts USING fts5(
40 path,
41 chunk_index,
42 content,
43 tokenize='porter'
44 );
45 "#,
46 )?;
47 Ok(())
48}
49
/// Makes sure the `memory_vec` vec0 table exists with the requested embedding
/// `dimension`.
///
/// The dimension the table was built with is remembered in
/// `_memory_vec_meta` under the key `'dimension'`:
/// * stored dimension equals `dimension` and the table exists: nothing to do;
/// * stored dimension equals `dimension` but the table is missing: create it;
/// * no stored dimension, or a different one: drop `memory_vec` (discarding its
///   rows) and create it again with the new width.
///
/// # Errors
/// Returns an error when `dimension` is zero or does not fit in an SQLite
/// integer, and for any SQLite failure. Unlike a plain `.ok()` on the metadata
/// lookup, a failing read is reported instead of being treated as "dimension
/// changed", so a transient database error can no longer wipe the vector table.
#[cfg(feature = "memory_vector")]
pub fn ensure_vec0_table(conn: &Connection, dimension: usize) -> Result<()> {
    ensure_vec_extension_loaded();

    // Reject unusable widths before touching existing data: a zero-width vector
    // column cannot hold embeddings, and failing after the DROP below would
    // leave the caller without a table.
    if dimension == 0 {
        anyhow::bail!("embedding dimension must be greater than zero");
    }
    let dimension_sql = i64::try_from(dimension).map_err(|_| {
        anyhow::anyhow!(
            "embedding dimension {} does not fit in an SQLite integer",
            dimension
        )
    })?;

    conn.execute_batch(
        r#"
        CREATE TABLE IF NOT EXISTS _memory_vec_meta (
            k TEXT PRIMARY KEY,
            v INTEGER NOT NULL
        );
        "#,
    )?;

    // Only "no such row" means "nothing stored yet"; every other error is real.
    let stored_dim: Option<i64> = match conn.query_row(
        "SELECT v FROM _memory_vec_meta WHERE k = 'dimension'",
        [],
        |row| row.get(0),
    ) {
        Ok(value) => Some(value),
        Err(rusqlite::Error::QueryReturnedNoRows) => None,
        Err(err) => return Err(err.into()),
    };

    let dimension_changed = stored_dim != Some(dimension_sql);

    if dimension_changed {
        conn.execute_batch("DROP TABLE IF EXISTS memory_vec")?;
    } else {
        let table_count: i64 = conn.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memory_vec'",
            [],
            |row| row.get(0),
        )?;
        if table_count > 0 {
            return Ok(());
        }
    }

    // Shared create path: record the width, then build the virtual table.
    conn.execute(
        "INSERT OR REPLACE INTO _memory_vec_meta (k, v) VALUES ('dimension', ?)",
        rusqlite::params![dimension_sql],
    )?;

    let sql = format!(
        r#"CREATE VIRTUAL TABLE memory_vec USING vec0(
            embedding float[{}],
            path text,
            chunk_index int,
            +content text
        )"#,
        dimension
    );
    conn.execute_batch(&sql)?;

    if dimension_changed {
        tracing::info!(
            dimension,
            "memory_vec table recreated for new embedding dimension"
        );
    }

    Ok(())
}
132
/// Public entry point to the chunker for the embedding path.
///
/// Delegates to `chunk_content`, so the returned chunks are the same ones
/// `index_file` writes to the full-text index for the same `content`.
#[cfg(feature = "memory_vector")]
pub fn chunk_content_for_embed(content: &str) -> Vec<String> {
    chunk_content(content)
}
138
139fn chunk_content(content: &str) -> Vec<String> {
140 let paragraphs: Vec<&str> = content
141 .split("\n\n")
142 .filter(|s| !s.trim().is_empty())
143 .collect();
144 let mut chunks = Vec::new();
145 let mut current = String::new();
146 let mut token_approx = 0;
147
148 for p in paragraphs {
149 let p_tokens = p.len() / 4; if token_approx + p_tokens > CHUNK_TOKEN_TARGET && !current.is_empty() {
151 chunks.push(current.trim().to_string());
152 let words: Vec<&str> = current.split_whitespace().collect();
154 let overlap_start = words.len().saturating_sub(CHUNK_OVERLAP / 4);
155 current = words[overlap_start..].join(" ");
156 token_approx = current.len() / 4;
157 }
158 if !current.is_empty() {
159 current.push_str("\n\n");
160 }
161 current.push_str(p);
162 token_approx += p_tokens;
163 }
164 if !current.trim().is_empty() {
165 chunks.push(current.trim().to_string());
166 }
167 chunks
168}
169
170pub fn index_file(conn: &Connection, path: &str, content: &str) -> Result<()> {
173 conn.execute(
174 "DELETE FROM memory_fts WHERE path = ?",
175 rusqlite::params![path],
176 )?;
177 let chunks = chunk_content(content);
178 for (i, chunk) in chunks.iter().enumerate() {
179 conn.execute(
180 "INSERT INTO memory_fts(path, chunk_index, content) VALUES (?, ?, ?)",
181 rusqlite::params![path, i as i64, chunk],
182 )?;
183 }
184 Ok(())
185}
186
/// Replaces the vector index entries for `path`.
///
/// `chunks[i]` is stored together with `embeddings[i]` and `i` as
/// `chunk_index`; existing `memory_vec` rows for `path` are removed first.
/// Each embedding is passed to SQLite as its raw `f32` bytes.
///
/// The delete and the inserts run inside one transaction when the connection
/// is in autocommit mode, so a failing insert (for example an embedding the
/// table rejects) rolls back instead of leaving the file partially indexed.
/// If the caller already has a transaction open, the statements join it.
///
/// # Errors
/// Returns an error if `chunks` and `embeddings` differ in length (nothing is
/// modified in that case) or if any SQLite operation fails.
#[cfg(feature = "memory_vector")]
pub fn index_file_vec(
    conn: &Connection,
    path: &str,
    chunks: &[String],
    embeddings: &[Vec<f32>],
) -> Result<()> {
    use zerocopy::AsBytes;

    if chunks.len() != embeddings.len() {
        anyhow::bail!(
            "Chunks and embeddings length mismatch: {} vs {}",
            chunks.len(),
            embeddings.len()
        );
    }

    // Dropping an uncommitted transaction rolls it back, which covers every
    // early return via `?` below.
    let own_tx = if conn.is_autocommit() {
        Some(conn.unchecked_transaction()?)
    } else {
        None
    };

    conn.execute(
        "DELETE FROM memory_vec WHERE path = ?",
        rusqlite::params![path],
    )?;

    {
        let mut stmt = conn.prepare(
            "INSERT INTO memory_vec(path, chunk_index, content, embedding) VALUES (?, ?, ?, ?)",
        )?;
        for (i, (chunk, emb)) in chunks.iter().zip(embeddings.iter()).enumerate() {
            stmt.execute(rusqlite::params![path, i as i64, chunk, emb.as_bytes()])?;
        }
    }

    if let Some(tx) = own_tx {
        tx.commit()?;
    }
    Ok(())
}
215
216pub fn search_bm25(conn: &Connection, query: &str, limit: i64) -> Result<Vec<MemoryHit>> {
218 let mut stmt = conn.prepare(
219 r#"
220 SELECT path, chunk_index, content, bm25(memory_fts) as rank
221 FROM memory_fts
222 WHERE memory_fts MATCH ?
223 ORDER BY rank
224 LIMIT ?
225 "#,
226 )?;
227 let rows = stmt.query_map(rusqlite::params![query, limit], |row| {
228 Ok(MemoryHit {
229 path: row.get(0)?,
230 chunk_index: row.get(1)?,
231 content: row.get(2)?,
232 score: row.get::<_, f64>(3).unwrap_or(0.0),
233 })
234 })?;
235 let mut hits: Vec<MemoryHit> = rows.filter_map(|r| r.ok()).collect();
236 hits.sort_by(|a, b| {
237 a.score
238 .partial_cmp(&b.score)
239 .unwrap_or(std::cmp::Ordering::Equal)
240 });
241 Ok(hits)
242}
243
/// Nearest-neighbour search over `memory_vec`, closest rows first.
///
/// `query_embedding` is passed to the `MATCH` constraint as raw `f32` bytes and
/// at most `limit` rows are returned. Each hit's `score` is the negated
/// distance reported by the table, so a larger score means a closer match.
///
/// # Errors
/// Returns an error if the statement cannot be prepared or executed, or if a
/// result row (including its distance) cannot be decoded. Such rows are
/// reported rather than silently dropped or given a distance of zero.
#[cfg(feature = "memory_vector")]
pub fn search_vec(
    conn: &Connection,
    query_embedding: &[f32],
    limit: i64,
) -> Result<Vec<MemoryHit>> {
    use zerocopy::AsBytes;

    let mut stmt = conn.prepare(
        r#"
        SELECT path, chunk_index, content, distance
        FROM memory_vec
        WHERE embedding MATCH ?1
        ORDER BY distance
        LIMIT ?2
        "#,
    )?;

    let hits = stmt
        .query_map(
            rusqlite::params![query_embedding.as_bytes(), limit],
            |row| {
                let distance: f64 = row.get(3)?;
                Ok(MemoryHit {
                    path: row.get(0)?,
                    chunk_index: row.get(1)?,
                    content: row.get(2)?,
                    score: -distance,
                })
            },
        )?
        .collect::<rusqlite::Result<Vec<MemoryHit>>>()?;

    Ok(hits)
}
275
/// Reports whether `memory_vec` currently holds at least one row.
///
/// This is a best-effort probe: any failure of the count query (for instance
/// because the table has not been created) is reported as `false`.
#[cfg(feature = "memory_vector")]
pub fn has_vec_index(conn: &Connection) -> bool {
    let row_count: rusqlite::Result<i64> =
        conn.query_row("SELECT COUNT(*) FROM memory_vec", [], |row| row.get(0));
    matches!(row_count, Ok(n) if n > 0)
}
285
/// One search result: a stored chunk plus the score of the search that found it.
#[derive(Debug, Clone)]
pub struct MemoryHit {
    /// The `path` value the chunk was indexed under.
    pub path: String,
    /// Zero-based position of the chunk within its file, as assigned at indexing time.
    pub chunk_index: i64,
    /// The chunk text.
    pub content: String,
    /// Score whose meaning depends on the producer: `search_bm25` stores the
    /// FTS5 `bm25()` rank, while `search_vec` stores the negated distance.
    pub score: f64,
}