1use crate::{raw_statement::RawStatement, Connection, Result, Statement};
4use hashlink::LruCache;
5use std::{
6 cell::RefCell,
7 ops::{Deref, DerefMut},
8 sync::Arc,
9};
10
11impl Connection {
12 #[inline]
39 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
40 self.cache.get(self, sql)
41 }
42
43 #[inline]
49 pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
50 self.cache.set_capacity(capacity);
51 }
52
53 #[inline]
55 pub fn flush_prepared_statement_cache(&self) {
56 self.cache.flush();
57 }
58}
59
/// LRU cache of prepared statements, keyed by the (trimmed) SQL text they were
/// prepared from.
///
/// The `RefCell` provides interior mutability so the cache can be used through
/// `&self` (e.g. from `Connection::prepare_cached`).
#[derive(Debug)]
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);

// SAFETY: NOTE(review): `Send` is asserted by hand, and the
// `clippy::non_send_fields_in_send_ty` allow below indicates that a field is not
// `Send` on its own (presumably `RawStatement`, which wraps a raw SQLite handle;
// confirm). Soundness presumably relies on the cache only ever being moved
// together with the `Connection` that owns it and its statements. Confirm against
// `Connection`'s own `Send` justification.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for StatementCache {}
66
/// A prepared statement checked out of a connection's [`StatementCache`].
///
/// Dereferences to [`Statement`]. When dropped, the statement is returned to the
/// cache it came from instead of being thrown away; call `discard` to opt out.
pub struct CachedStatement<'conn> {
    /// The wrapped statement. It is `Some` for as long as the wrapper is usable.
    /// It is taken in `Drop` (to hand it back to the cache) and cleared by `discard`.
    stmt: Option<Statement<'conn>>,
    /// The cache that receives the statement back when this wrapper is dropped.
    cache: &'conn StatementCache,
}
76
77impl<'conn> Deref for CachedStatement<'conn> {
78 type Target = Statement<'conn>;
79
80 #[inline]
81 fn deref(&self) -> &Statement<'conn> {
82 self.stmt.as_ref().unwrap()
83 }
84}
85
86impl<'conn> DerefMut for CachedStatement<'conn> {
87 #[inline]
88 fn deref_mut(&mut self) -> &mut Statement<'conn> {
89 self.stmt.as_mut().unwrap()
90 }
91}
92
93impl Drop for CachedStatement<'_> {
94 #[allow(unused_must_use)]
95 #[inline]
96 fn drop(&mut self) {
97 if let Some(stmt) = self.stmt.take() {
98 self.cache.cache_stmt(unsafe { stmt.into_raw() });
99 }
100 }
101}
102
103impl CachedStatement<'_> {
104 #[inline]
105 fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
106 CachedStatement {
107 stmt: Some(stmt),
108 cache,
109 }
110 }
111
112 #[inline]
115 pub fn discard(mut self) {
116 self.stmt = None;
117 }
118}
119
120impl StatementCache {
121 #[inline]
123 pub fn with_capacity(capacity: usize) -> StatementCache {
124 StatementCache(RefCell::new(LruCache::new(capacity)))
125 }
126
127 #[inline]
128 fn set_capacity(&self, capacity: usize) {
129 self.0.borrow_mut().set_capacity(capacity);
130 }
131
132 fn get<'conn>(&'conn self, conn: &'conn Connection, sql: &str) -> Result<CachedStatement<'conn>> {
140 let trimmed = sql.trim();
141 let mut cache = self.0.borrow_mut();
142 let stmt = match cache.remove(trimmed) {
143 Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
144 None => conn.prepare(trimmed),
145 };
146 stmt.map(|mut stmt| {
147 stmt.stmt.set_statement_cache_key(trimmed);
148 CachedStatement::new(stmt, self)
149 })
150 }
151
152 fn cache_stmt(&self, stmt: RawStatement) {
154 if stmt.is_null() {
155 return;
156 }
157 let mut cache = self.0.borrow_mut();
158 stmt.clear_bindings();
159 if let Some(sql) = stmt.statement_cache_key() {
160 cache.insert(sql, stmt);
161 } else {
162 debug_assert!(
163 false,
164 "bug in statement cache code, statement returned to cache that without key"
165 );
166 }
167 }
168
169 #[inline]
170 fn flush(&self) {
171 let mut cache = self.0.borrow_mut();
172 cache.clear();
173 }
174}
175
#[cfg(test)]
mod test {
    use super::StatementCache;
    use crate::{Connection, Result};
    use fallible_iterator::FallibleIterator;

    // Test-only introspection helpers that read the underlying LRU cache directly.
    impl StatementCache {
        fn clear(&self) {
            self.0.borrow_mut().clear();
        }

        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    // A statement leaves the cache while checked out and returns when dropped.
    // Clearing the cache empties it but keeps its capacity.
    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA database_list";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());

        // Second checkout reuses the cached entry (removed while in use).
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    // Capacity 0 empties the cache and stops retaining statements. Raising the
    // capacity again re-enables caching.
    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA database_list";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());

        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // A discarded statement is not put back into the cache.
    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA database_list";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
            stmt.discard();
        }
        assert_eq!(0, cache.len());
        Ok(())
    }

    // A cached `SELECT *` must reflect a schema change made after it was first
    // prepared. The second query is expected to see the newly added column `y`.
    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r#"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
            "#,
        )?;

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next());
        }

        db.execute_batch(
            r#"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
            "#,
        )?;

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query([])?.map(|r| <(i32, i32)>::try_from(r)).next()
            );
        }
        Ok(())
    }

    // The temporary `CachedStatement` is dropped at the end of its statement, so it
    // sits in the cache when `close` runs. Closing must still succeed.
    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("SELECT * FROM sqlite_master;")?;

        conn.close().expect("connection not closed");
        Ok(())
    }

    // SQL with trailing whitespace maps to a single cache entry across calls,
    // because the key is the trimmed text.
    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        let sql = "PRAGMA database_list; ";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!("memory", stmt.query_row([], |r| r.get::<_, String>(1))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // Preparing an empty SQL string through the cache reports an error.
    #[test]
    fn test_cannot_prepare_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let result = conn.prepare_cached("");
        assert!(result.is_err());
        Ok(())
    }
}