Skip to main content

duckdb/
r2d2.rs

1#![deny(warnings)]
2//! # Duckdb-rs support for the `r2d2` connection pool.
3//!
4//!
5//! Integrated with: [r2d2](https://crates.io/crates/r2d2)
6//!
7//!
8//! ## Example
9//!
10//! ```rust,no_run
11//! use std::thread;
12//! use duckdb::{DuckdbConnectionManager, params};
13//! use r2d2;
14//!
15//! let manager = DuckdbConnectionManager::file("file.db").unwrap();
16//! let pool = r2d2::Pool::new(manager).unwrap();
17//! pool.get()
18//!     .unwrap()
19//!     .execute("CREATE TABLE IF NOT EXISTS foo (bar INTEGER)", params![])
20//!     .unwrap();
21//!
22//! (0..10)
23//!     .map(|i| {
24//!         let pool = pool.clone();
25//!         thread::spawn(move || {
26//!             let conn = pool.get().unwrap();
27//!             conn.execute("INSERT INTO foo (bar) VALUES (?)", &[&i])
28//!                 .unwrap();
29//!         })
30//!     })
31//!     .collect::<Vec<_>>()
32//!     .into_iter()
33//!     .map(thread::JoinHandle::join)
34//!     .collect::<Result<_, _>>()
35//!     .unwrap()
36//! ```
37use crate::{Config, Connection, Error, Result};
38use std::{
39    path::Path,
40    sync::{Arc, Mutex},
41};
42
43#[cfg(feature = "vscalar")]
44use crate::vscalar::VScalar;
45#[cfg(feature = "vscalar")]
46use std::fmt::Debug;
47
48#[cfg(feature = "vtab")]
49use crate::vtab::VTab;
50
51/// An `r2d2::ManageConnection` for `duckdb::Connection`s.
52pub struct DuckdbConnectionManager {
53    connection: Arc<Mutex<Connection>>,
54}
55
56impl DuckdbConnectionManager {
57    /// Creates a new `DuckdbConnectionManager` from file.
58    pub fn file<P: AsRef<Path>>(path: P) -> Result<Self> {
59        Ok(Self {
60            connection: Arc::new(Mutex::new(Connection::open(path)?)),
61        })
62    }
63    /// Creates a new `DuckdbConnectionManager` from file with flags.
64    pub fn file_with_flags<P: AsRef<Path>>(path: P, config: Config) -> Result<Self> {
65        Ok(Self {
66            connection: Arc::new(Mutex::new(Connection::open_with_flags(path, config)?)),
67        })
68    }
69
70    /// Creates a new `DuckdbConnectionManager` from memory.
71    pub fn memory() -> Result<Self> {
72        Ok(Self {
73            connection: Arc::new(Mutex::new(Connection::open_in_memory()?)),
74        })
75    }
76
77    /// Creates a new `DuckdbConnectionManager` from memory with flags.
78    pub fn memory_with_flags(config: Config) -> Result<Self> {
79        Ok(Self {
80            connection: Arc::new(Mutex::new(Connection::open_in_memory_with_flags(config)?)),
81        })
82    }
83
84    /// Register a table function.
85    #[cfg(feature = "vtab")]
86    pub fn register_table_function<T: VTab>(&self, name: &str) -> Result<()> {
87        let conn = self.connection.lock().unwrap();
88        conn.register_table_function::<T>(name)
89    }
90
91    /// Register a scalar function.
92    #[cfg(feature = "vscalar")]
93    pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> Result<()>
94    where
95        S::State: Debug + Default,
96    {
97        let conn = self.connection.lock().unwrap();
98        conn.register_scalar_function::<S>(name)
99    }
100}
101
102impl r2d2::ManageConnection for DuckdbConnectionManager {
103    type Connection = Connection;
104    type Error = Error;
105
106    fn connect(&self) -> Result<Self::Connection, Self::Error> {
107        let conn = self.connection.lock().unwrap();
108        conn.try_clone()
109    }
110
111    fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
112        conn.execute_batch("")
113    }
114
115    fn has_broken(&self, _: &mut Self::Connection) -> bool {
116        false
117    }
118}
119
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::Value;
    use std::{sync::mpsc, thread};

    use tempfile::TempDir;

    /// Builds a pool of at most two connections on `manager` and makes two
    /// threads hold one connection each at the same time.
    ///
    /// Each thread signals after its checkout and waits for the other
    /// thread's signal before releasing, so the function only returns if both
    /// connections were checked out simultaneously. A final checkout confirms
    /// the pool is still usable afterwards.
    fn assert_concurrent_checkout(manager: DuckdbConnectionManager) {
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = pool.clone();
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = pool.clone();
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    /// Concurrent checkout against a manager created with `memory()`.
    #[test]
    fn test_basic() -> Result<()> {
        let manager = DuckdbConnectionManager::memory()?;
        assert_concurrent_checkout(manager);
        Ok(())
    }

    /// Concurrent checkout against a manager backed by a database file on disk.
    #[test]
    fn test_file() -> Result<()> {
        // `dir` is declared first so it outlives the pool and is only cleaned
        // up once every connection has been dropped.
        let dir = TempDir::with_prefix("r2d2-duckdb").expect("Could not create temporary directory");
        let manager = DuckdbConnectionManager::file(dir.path().join("test.db"))?;
        assert_concurrent_checkout(manager);
        Ok(())
    }

    /// With `test_on_check_out(true)` the pool validates a connection before
    /// handing it out, so a successful `get` exercises `is_valid`.
    #[test]
    fn test_is_valid() -> Result<()> {
        let manager = DuckdbConnectionManager::file(":memory:")?;
        let pool = r2d2::Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
        Ok(())
    }

    /// We specify a directory as a database. This is bound to fail.
    #[test]
    fn test_error_handling() -> Result<()> {
        let dir = TempDir::with_prefix("r2d2-duckdb").expect("Could not create temporary directory");
        // `file` accepts any `AsRef<Path>`, so the path is passed directly
        // rather than through a fallible UTF-8 conversion.
        assert!(DuckdbConnectionManager::file(dir.path()).is_err());
        Ok(())
    }

    /// Verifies that the `Config` given to `file_with_flags` is in effect on
    /// pooled connections.
    #[test]
    fn test_with_flags() -> Result<()> {
        let config = Config::default()
            .access_mode(crate::AccessMode::ReadWrite)?
            .default_null_order(crate::DefaultNullOrder::NullsLast)?
            .default_order(crate::DefaultOrder::Desc)?
            .enable_external_access(true)?
            .enable_object_cache(false)?
            .max_memory("2GB")?
            .threads(4)?;
        let manager = DuckdbConnectionManager::file_with_flags(":memory:", config)?;
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();
        let conn = pool.get().unwrap();
        conn.execute_batch("CREATE TABLE foo(x Text)")?;

        let mut stmt = conn.prepare("INSERT INTO foo(x) VALUES (?)")?;
        stmt.execute([&"a"])?;
        stmt.execute([&"b"])?;
        stmt.execute([&"c"])?;
        stmt.execute([Value::Null])?;

        let val: Result<Vec<Option<String>>> = conn
            .prepare("SELECT x FROM foo ORDER BY x")?
            .query_and_then([], |row| row.get(0))?
            .collect();
        let val = val?;
        // The query has no explicit direction or NULL placement, so the
        // configured defaults (descending, NULLs last) decide the row order.
        let mut iter = val.iter();
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "c");
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "b");
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "a");
        assert!(iter.next().unwrap().is_none());
        assert_eq!(iter.next(), None);

        Ok(())
    }
}