1#![deny(warnings)]
2use crate::{Config, Connection, Error, Result};
38use std::{
39 path::Path,
40 sync::{Arc, Mutex},
41};
42
43#[cfg(feature = "vscalar")]
44use crate::vscalar::VScalar;
45#[cfg(feature = "vscalar")]
46use std::fmt::Debug;
47
48#[cfg(feature = "vtab")]
49use crate::vtab::VTab;
50
51pub struct DuckdbConnectionManager {
53 connection: Arc<Mutex<Connection>>,
54}
55
56impl DuckdbConnectionManager {
57 pub fn file<P: AsRef<Path>>(path: P) -> Result<Self> {
59 Ok(Self {
60 connection: Arc::new(Mutex::new(Connection::open(path)?)),
61 })
62 }
63 pub fn file_with_flags<P: AsRef<Path>>(path: P, config: Config) -> Result<Self> {
65 Ok(Self {
66 connection: Arc::new(Mutex::new(Connection::open_with_flags(path, config)?)),
67 })
68 }
69
70 pub fn memory() -> Result<Self> {
72 Ok(Self {
73 connection: Arc::new(Mutex::new(Connection::open_in_memory()?)),
74 })
75 }
76
77 pub fn memory_with_flags(config: Config) -> Result<Self> {
79 Ok(Self {
80 connection: Arc::new(Mutex::new(Connection::open_in_memory_with_flags(config)?)),
81 })
82 }
83
84 #[cfg(feature = "vtab")]
86 pub fn register_table_function<T: VTab>(&self, name: &str) -> Result<()> {
87 let conn = self.connection.lock().unwrap();
88 conn.register_table_function::<T>(name)
89 }
90
91 #[cfg(feature = "vscalar")]
93 pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> Result<()>
94 where
95 S::State: Debug + Default,
96 {
97 let conn = self.connection.lock().unwrap();
98 conn.register_scalar_function::<S>(name)
99 }
100}
101
102impl r2d2::ManageConnection for DuckdbConnectionManager {
103 type Connection = Connection;
104 type Error = Error;
105
106 fn connect(&self) -> Result<Self::Connection, Self::Error> {
107 let conn = self.connection.lock().unwrap();
108 conn.try_clone()
109 }
110
111 fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
112 conn.execute_batch("")
113 }
114
115 fn has_broken(&self, _: &mut Self::Connection) -> bool {
116 false
117 }
118}
119
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::Value;
    use std::{sync::mpsc, thread};

    use tempfile::TempDir;

    /// Holds two pooled connections at the same time, one per thread, before
    /// releasing either. Then checks that the pool can still hand one out.
    ///
    /// Each thread signals once it owns a connection and waits for the other
    /// thread's signal before dropping it, so the two checkouts overlap.
    /// The pool therefore needs `max_size >= 2`.
    fn assert_concurrent_checkouts(pool: &r2d2::Pool<DuckdbConnectionManager>) {
        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = pool.clone();
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = pool.clone();
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    #[test]
    fn test_basic() -> Result<()> {
        let manager = DuckdbConnectionManager::memory()?;
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();

        assert_concurrent_checkouts(&pool);
        Ok(())
    }

    #[test]
    fn test_file() -> Result<()> {
        // Use a real on-disk database; `":memory:"` would not exercise the
        // file-backed path at all.
        let dir = TempDir::with_prefix("r2d2-duckdb").expect("Could not create temporary directory");
        let db_path = dir.path().join("test.db");
        let manager = DuckdbConnectionManager::file(&db_path)?;
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();

        assert_concurrent_checkouts(&pool);

        // Two pooled connections must see the same on-disk database.
        let writer = pool.get().unwrap();
        let reader = pool.get().unwrap();
        writer.execute_batch("CREATE TABLE shared(x INTEGER); INSERT INTO shared VALUES (42)")?;
        let x: i32 = reader.query_row("SELECT x FROM shared", [], |row| row.get(0))?;
        assert_eq!(x, 42);
        assert!(db_path.exists());
        Ok(())
    }

    #[test]
    fn test_memory_with_flags() -> Result<()> {
        let config = Config::default().threads(2)?;
        let manager = DuckdbConnectionManager::memory_with_flags(config)?;
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();

        let conn = pool.get().unwrap();
        let x: i32 = conn.query_row("SELECT 42::INTEGER", [], |row| row.get(0))?;
        assert_eq!(x, 42);
        Ok(())
    }

    #[test]
    fn test_is_valid() -> Result<()> {
        let manager = DuckdbConnectionManager::file(":memory:")?;
        let pool = r2d2::Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
        Ok(())
    }

    #[test]
    fn test_error_handling() -> Result<()> {
        let dir = TempDir::with_prefix("r2d2-duckdb").expect("Could not create temporary directory");
        let dirpath = dir.path().to_str().unwrap();
        // Opening a directory as a database must fail rather than panic.
        assert!(DuckdbConnectionManager::file(dirpath).is_err());
        Ok(())
    }

    #[test]
    fn test_with_flags() -> Result<()> {
        let config = Config::default()
            .access_mode(crate::AccessMode::ReadWrite)?
            .default_null_order(crate::DefaultNullOrder::NullsLast)?
            .default_order(crate::DefaultOrder::Desc)?
            .enable_external_access(true)?
            .enable_object_cache(false)?
            .max_memory("2GB")?
            .threads(4)?;
        let manager = DuckdbConnectionManager::file_with_flags(":memory:", config)?;
        let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();
        let conn = pool.get().unwrap();
        conn.execute_batch("CREATE TABLE foo(x Text)")?;

        let mut stmt = conn.prepare("INSERT INTO foo(x) VALUES (?)")?;
        stmt.execute([&"a"])?;
        stmt.execute([&"b"])?;
        stmt.execute([&"c"])?;
        stmt.execute([Value::Null])?;

        // With `default_order = Desc` and `NullsLast`, a plain ORDER BY yields
        // c, b, a and then the NULL row.
        let val: Result<Vec<Option<String>>> = conn
            .prepare("SELECT x FROM foo ORDER BY x")?
            .query_and_then([], |row| row.get(0))?
            .collect();
        let val = val?;
        let mut iter = val.iter();
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "c");
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "b");
        assert_eq!(iter.next().unwrap().as_ref().unwrap(), "a");
        assert!(iter.next().unwrap().is_none());
        assert_eq!(iter.next(), None);

        Ok(())
    }
}