Skip to main content

solo_storage/
reader.rs

// SPDX-License-Identifier: Apache-2.0

//! `ReaderPool`: pool of read-only SQLite connections backed by
//! `deadpool-sqlite`. Each newly-created connection has its raw SQLCipher
//! key bound via a `post_create` hook, so the keying cost is paid once per
//! connection, not per query. See ADR-0003 §"Trait shapes" and §P8-A/P8-B.
//!
//! Pool size defaults to 2 per ADR-0003 §"pool size": the SQLite WAL gives
//! readers snapshot isolation, but with a single writer and a normal hard
//! drive, more than a couple of concurrent reads saturate IO before they
//! improve latency.
//!
//! The same `Arc<dyn VectorIndex + Send + Sync>` that the writer holds is
//! shared with the pool, so vector search runs on whichever Tokio task
//! happens to call `interact`. Concurrency is provided by `hnsw_rs`'s
//! internal `parking_lot::RwLock`, not by the pool.
//!
//! ## A note on lifetime in tests
//!
//! `deadpool-sqlite::Pool::drop` schedules cleanup via the current Tokio
//! runtime. If the runtime is torn down before the pool is dropped, that
//! schedule call panics with "no reactor running". Production code never
//! hits this — the pool lives for the daemon's lifetime, and the daemon
//! exits only after the runtime shuts down gracefully. In tests, **build
//! and drop the pool inside `runtime.block_on(async { ... })`** so the
//! pool's drop runs while the runtime is still active.
27
28use std::path::Path;
29use std::sync::Arc;
30
31use deadpool_sqlite::{Config, Hook, HookError, Pool, Runtime};
32use solo_core::{Error, Result, VectorIndex};
33
34use crate::key_material::KeyMaterial;
35
/// Connection cap used by [`ReaderPool::new`]; the figure comes from
/// ADR-0003 §"pool size".
pub const DEFAULT_POOL_SIZE: usize = 2;
38
39/// Shared read pool. Cheap to clone (the inner `Pool` is `Arc`-based and
40/// the HNSW handle is `Arc<dyn VectorIndex>`). Cloning gives multiple
41/// owners — useful when the daemon hands one to the MCP server and keeps
42/// one for shutdown.
43#[derive(Clone)]
44pub struct ReaderPool {
45    pool: Pool,
46    hnsw: Arc<dyn VectorIndex + Send + Sync>,
47}
48
49impl ReaderPool {
50    pub fn new(
51        db_path: &Path,
52        key: Option<KeyMaterial>,
53        hnsw: Arc<dyn VectorIndex + Send + Sync>,
54    ) -> Result<Self> {
55        Self::with_size(db_path, key, DEFAULT_POOL_SIZE, hnsw)
56    }
57
58    pub fn with_size(
59        db_path: &Path,
60        key: Option<KeyMaterial>,
61        size: usize,
62        hnsw: Arc<dyn VectorIndex + Send + Sync>,
63    ) -> Result<Self> {
64        let cfg = Config::new(db_path);
65        let mut builder = cfg
66            .builder(Runtime::Tokio1)
67            .map_err(|e| Error::storage(format!("deadpool config: {e:?}")))?
68            .max_size(size);
69
70        if let Some(key) = key {
71            let key_hex = key.as_hex();
72            builder = builder.post_create(Hook::async_fn(move |conn, _metrics| {
73                let key_hex = key_hex.clone();
74                Box::pin(async move {
75                    // Wrap in Zeroizing<String> so the raw key bytes are
76                    // wiped from the heap on drop — parallels the
77                    // init.rs + backup.rs PRAGMA key bind sites.
78                    let pragma: zeroize::Zeroizing<String> =
79                        zeroize::Zeroizing::new(format!("PRAGMA key = \"x'{}'\"", &*key_hex));
80                    conn.interact(move |c| {
81                        c.execute_batch(&pragma)?;
82                        c.execute_batch(
83                            "PRAGMA foreign_keys = ON;
84                             PRAGMA busy_timeout = 5000;",
85                        )?;
86                        Ok::<_, rusqlite::Error>(())
87                    })
88                    .await
89                    .map_err(|e| HookError::message(format!("interact: {e}")))?
90                    .map_err(|e| HookError::message(format!("PRAGMA key: {e}")))?;
91                    Ok(())
92                })
93            }));
94        } else {
95            builder = builder.post_create(Hook::async_fn(|conn, _metrics| {
96                Box::pin(async move {
97                    conn.interact(|c| {
98                        c.execute_batch(
99                            "PRAGMA foreign_keys = ON;
100                             PRAGMA busy_timeout = 5000;",
101                        )
102                    })
103                    .await
104                    .map_err(|e| HookError::message(format!("interact: {e}")))?
105                    .map_err(|e| HookError::message(format!("PRAGMA setup: {e}")))?;
106                    Ok(())
107                })
108            }));
109        }
110
111        let pool = builder
112            .build()
113            .map_err(|e| Error::storage(format!("deadpool build: {e:?}")))?;
114        Ok(Self { pool, hnsw })
115    }
116
117    pub async fn interact<F, R>(&self, f: F) -> Result<R>
118    where
119        F: FnOnce(&mut rusqlite::Connection) -> rusqlite::Result<R> + Send + 'static,
120        R: Send + 'static,
121    {
122        let conn = self
123            .pool
124            .get()
125            .await
126            .map_err(|e| Error::storage(format!("pool get: {e:?}")))?;
127        conn.interact(f)
128            .await
129            .map_err(|e| Error::storage(format!("interact: {e}")))?
130            .map_err(|e| Error::storage(format!("rusqlite: {e}")))
131    }
132
133    pub fn hnsw(&self) -> &Arc<dyn VectorIndex + Send + Sync> {
134        &self.hnsw
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{
        StubVectorIndex, fixture_embedding, fixture_episode, open_test_db_at,
    };
    use crate::writer::WriterActor;

    /// Latest `schema_migrations.version` a freshly set-up test database
    /// reports. Keep this in step with the migration chain; the tests use it
    /// to confirm the pool opened the migrated database file.
    const EXPECTED_SCHEMA_VERSION: u32 = 10;

    /// Two-worker multi-threaded runtime. Per the module docs, every pool
    /// must be built *and dropped* inside `block_on` on this runtime.
    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Fresh stub vector index, built with `StubVectorIndex::new(4)` to match
    /// the `fixture_embedding(4)` used below.
    fn stub_hnsw() -> Arc<dyn VectorIndex + Send + Sync> {
        Arc::new(StubVectorIndex::new(4))
    }

    /// Read `MAX(version)` from `schema_migrations` through the public
    /// `ReaderPool::interact` API.
    async fn max_schema_version(pool: &ReaderPool) -> u32 {
        pool.interact(|conn| {
            conn.query_row("SELECT MAX(version) FROM schema_migrations", [], |row| {
                row.get(0)
            })
        })
        .await
        .unwrap()
    }

    #[test]
    fn pool_returns_connections() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let _ = open_test_db_at(&path);

        runtime.block_on(async {
            // Pool's life cycle must run within the runtime; create + use +
            // drop inside this async block so pool.drop sees a live reactor.
            let pool = ReaderPool::new(&path, None, stub_hnsw()).unwrap();
            assert_eq!(max_schema_version(&pool).await, EXPECTED_SCHEMA_VERSION);
        });
    }

    #[test]
    fn reader_sees_writes_committed_through_writer_actor() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let writer_conn = open_test_db_at(&path);
        let hnsw = stub_hnsw();

        // Writer can be spawned outside the runtime — it owns its OS thread.
        let crate::writer::WriterSpawn { handle, join: _ } =
            WriterActor::spawn(writer_conn, hnsw.clone());

        runtime.block_on(async {
            let pool = ReaderPool::new(&path, None, hnsw).unwrap();

            let episode = fixture_episode("reader-visibility test");
            let mid = handle
                .remember(episode.clone(), fixture_embedding(4))
                .await
                .unwrap();
            assert_eq!(mid, episode.memory_id);

            let mid_str = mid.to_string();
            let content: String = pool
                .interact(move |conn| {
                    conn.query_row(
                        "SELECT content FROM episodes WHERE memory_id = ?",
                        [mid_str],
                        |row| row.get(0),
                    )
                })
                .await
                .unwrap();
            assert_eq!(content, "reader-visibility test");
        });

        drop(handle);
    }

    #[test]
    fn many_concurrent_reads_serve_from_pool() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let _ = open_test_db_at(&path);

        runtime.block_on(async {
            // 32 tasks share a pool capped at 4 connections, exercising the
            // public clone + interact path rather than the raw deadpool pool.
            let pool = ReaderPool::with_size(&path, None, 4, stub_hnsw()).unwrap();

            let tasks: Vec<_> = (0..32)
                .map(|_| {
                    let pool = pool.clone();
                    tokio::spawn(async move { max_schema_version(&pool).await })
                })
                .collect();

            let mut versions = Vec::with_capacity(tasks.len());
            for task in tasks {
                versions.push(task.await.unwrap());
            }
            assert_eq!(versions.len(), 32);
            assert!(versions.iter().all(|&v| v == EXPECTED_SCHEMA_VERSION));
        });
    }
}