1use std::path::Path;
29use std::sync::Arc;
30
31use deadpool_sqlite::{Config, Hook, HookError, Pool, Runtime};
32use solo_core::{Error, Result, VectorIndex};
33
34use crate::key_material::KeyMaterial;
35
/// Maximum number of pooled connections used by [`ReaderPool::new`].
pub const DEFAULT_POOL_SIZE: usize = 2;

/// Pool of SQLite connections to a single database file, bundled with a shared
/// vector-index handle.
///
/// NOTE(review): the name and tests suggest this is for reads only (the tests route
/// writes through `crate::writer::WriterActor`), but nothing here enforces read-only
/// access — confirm that callers never write through it.
///
/// `Clone` clones the deadpool handle and the `Arc` around the index; it does not
/// open new connections.
#[derive(Clone)]
pub struct ReaderPool {
    // deadpool-managed connections; each one is initialised by the post-create hook
    // installed in `ReaderPool::with_size`.
    pool: Pool,
    // Vector index handed in at construction; only exposed via `hnsw()`, never
    // queried by the pool itself.
    hnsw: Arc<dyn VectorIndex + Send + Sync>,
}
48
49impl ReaderPool {
50 pub fn new(
51 db_path: &Path,
52 key: Option<KeyMaterial>,
53 hnsw: Arc<dyn VectorIndex + Send + Sync>,
54 ) -> Result<Self> {
55 Self::with_size(db_path, key, DEFAULT_POOL_SIZE, hnsw)
56 }
57
58 pub fn with_size(
59 db_path: &Path,
60 key: Option<KeyMaterial>,
61 size: usize,
62 hnsw: Arc<dyn VectorIndex + Send + Sync>,
63 ) -> Result<Self> {
64 let cfg = Config::new(db_path);
65 let mut builder = cfg
66 .builder(Runtime::Tokio1)
67 .map_err(|e| Error::storage(format!("deadpool config: {e:?}")))?
68 .max_size(size);
69
70 if let Some(key) = key {
71 let key_hex = key.as_hex();
72 builder = builder.post_create(Hook::async_fn(move |conn, _metrics| {
73 let key_hex = key_hex.clone();
74 Box::pin(async move {
75 let pragma: zeroize::Zeroizing<String> =
79 zeroize::Zeroizing::new(format!("PRAGMA key = \"x'{}'\"", &*key_hex));
80 conn.interact(move |c| {
81 c.execute_batch(&pragma)?;
82 c.execute_batch(
83 "PRAGMA foreign_keys = ON;
84 PRAGMA busy_timeout = 5000;",
85 )?;
86 Ok::<_, rusqlite::Error>(())
87 })
88 .await
89 .map_err(|e| HookError::message(format!("interact: {e}")))?
90 .map_err(|e| HookError::message(format!("PRAGMA key: {e}")))?;
91 Ok(())
92 })
93 }));
94 } else {
95 builder = builder.post_create(Hook::async_fn(|conn, _metrics| {
96 Box::pin(async move {
97 conn.interact(|c| {
98 c.execute_batch(
99 "PRAGMA foreign_keys = ON;
100 PRAGMA busy_timeout = 5000;",
101 )
102 })
103 .await
104 .map_err(|e| HookError::message(format!("interact: {e}")))?
105 .map_err(|e| HookError::message(format!("PRAGMA setup: {e}")))?;
106 Ok(())
107 })
108 }));
109 }
110
111 let pool = builder
112 .build()
113 .map_err(|e| Error::storage(format!("deadpool build: {e:?}")))?;
114 Ok(Self { pool, hnsw })
115 }
116
117 pub async fn interact<F, R>(&self, f: F) -> Result<R>
118 where
119 F: FnOnce(&mut rusqlite::Connection) -> rusqlite::Result<R> + Send + 'static,
120 R: Send + 'static,
121 {
122 let conn = self
123 .pool
124 .get()
125 .await
126 .map_err(|e| Error::storage(format!("pool get: {e:?}")))?;
127 conn.interact(f)
128 .await
129 .map_err(|e| Error::storage(format!("interact: {e}")))?
130 .map_err(|e| Error::storage(format!("rusqlite: {e}")))
131 }
132
133 pub fn hnsw(&self) -> &Arc<dyn VectorIndex + Send + Sync> {
134 &self.hnsw
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{
        StubVectorIndex, fixture_embedding, fixture_episode, open_test_db_at,
    };
    use crate::writer::WriterActor;

    /// Highest `schema_migrations.version` the migrated test database should report.
    const EXPECTED_SCHEMA_VERSION: u32 = 10;

    /// Query returning the highest applied migration version.
    const MAX_VERSION_SQL: &str = "SELECT MAX(version) FROM schema_migrations";

    /// Multi-threaded Tokio runtime with two workers and all drivers enabled.
    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Four-dimensional stub vector index, type-erased the way `ReaderPool` stores it.
    fn stub_hnsw() -> Arc<dyn VectorIndex + Send + Sync> {
        Arc::new(StubVectorIndex::new(4))
    }

    #[test]
    fn pool_returns_connections() {
        let runtime = rt();
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        // Create and migrate the database, then release that connection.
        drop(open_test_db_at(&db_path));

        runtime.block_on(async {
            let pool = ReaderPool::new(&db_path, None, stub_hnsw()).unwrap();
            let version: u32 = pool
                .interact(|conn| conn.query_row(MAX_VERSION_SQL, [], |row| row.get(0)))
                .await
                .unwrap();
            assert_eq!(version, EXPECTED_SCHEMA_VERSION);
        });
    }

    #[test]
    fn reader_sees_writes_committed_through_writer_actor() {
        let runtime = rt();
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        let writer_conn = open_test_db_at(&db_path);
        let hnsw = stub_hnsw();

        let crate::writer::WriterSpawn { handle, join: _ } =
            WriterActor::spawn(writer_conn, Arc::clone(&hnsw));

        runtime.block_on(async {
            let pool = ReaderPool::new(&db_path, None, hnsw).unwrap();

            let episode = fixture_episode("reader-visibility test");
            let memory_id = handle
                .remember(episode.clone(), fixture_embedding(4))
                .await
                .unwrap();
            assert_eq!(memory_id, episode.memory_id);

            // The reader must observe the row the writer actor just committed.
            let id_param = memory_id.to_string();
            let content: String = pool
                .interact(move |conn| {
                    conn.query_row(
                        "SELECT content FROM episodes WHERE memory_id = ?",
                        [id_param],
                        |row| row.get(0),
                    )
                })
                .await
                .unwrap();
            assert_eq!(content, "reader-visibility test");
        });

        drop(handle);
    }

    #[test]
    fn many_concurrent_reads_serve_from_pool() {
        const TASKS: usize = 32;

        let runtime = rt();
        let dir = tempfile::TempDir::new().unwrap();
        let db_path = dir.path().join("test.db");
        drop(open_test_db_at(&db_path));

        runtime.block_on(async {
            let pool = ReaderPool::with_size(&db_path, None, 4, stub_hnsw()).unwrap();

            // Far more tasks than connections: they must queue on the pool.
            let handles: Vec<_> = (0..TASKS)
                .map(|_| {
                    let inner = pool.pool.clone();
                    tokio::spawn(async move {
                        let conn = inner.get().await.unwrap();
                        conn.interact(|c| {
                            c.query_row(MAX_VERSION_SQL, [], |row| row.get::<_, u32>(0))
                        })
                        .await
                        .unwrap()
                        .unwrap()
                    })
                })
                .collect();

            let mut counts = Vec::with_capacity(handles.len());
            for handle in handles {
                counts.push(handle.await.unwrap());
            }
            assert_eq!(counts.len(), TASKS);
            assert!(counts.iter().all(|&v| v == EXPECTED_SCHEMA_VERSION));
        });
    }
}