1use std::ops::Deref;
4
5use crate::{
6 error::Error,
7 ffi,
8 types::{ToSql, ToSqlOutput, ValueRef},
9 Connection, DatabaseName, Result, Row,
10};
11
/// Incrementally built SQL text used by the pragma helpers on `Connection`.
///
/// Identifiers and string literals are escaped as they are appended. The
/// finished buffer is handed to statement-executing methods through its
/// `Deref<Target = str>` implementation.
#[derive(Debug, Default)]
pub struct Sql {
    /// The SQL text accumulated so far.
    buf: String,
}
15
16impl Sql {
17 pub fn new() -> Sql {
18 Sql { buf: String::new() }
19 }
20
21 pub fn push_pragma(&mut self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str) -> Result<()> {
22 self.push_keyword("PRAGMA")?;
23 self.push_space();
24 if let Some(schema_name) = schema_name {
25 self.push_schema_name(schema_name);
26 self.push_dot();
27 }
28 self.push_keyword(pragma_name)
29 }
30
31 pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
32 if !keyword.is_empty() && is_identifier(keyword) {
33 self.buf.push_str(keyword);
34 Ok(())
35 } else {
36 Err(Error::DuckDBFailure(
37 ffi::Error::new(ffi::DuckDBError),
38 Some(format!("Invalid keyword \"{keyword}\"")),
39 ))
40 }
41 }
42
43 pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) {
44 match schema_name {
45 DatabaseName::Main => self.buf.push_str("main"),
46 DatabaseName::Temp => self.buf.push_str("temp"),
47 DatabaseName::Attached(s) => self.push_identifier(s),
48 };
49 }
50
51 pub fn push_identifier(&mut self, s: &str) {
52 if is_identifier(s) {
53 self.buf.push_str(s);
54 } else {
55 self.wrap_and_escape(s, '"');
56 }
57 }
58
59 pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
60 let value = value.to_sql()?;
61 let value = match value {
62 ToSqlOutput::Borrowed(v) => v,
63 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
64 };
65 match value {
66 ValueRef::BigInt(i) => {
67 self.push_int(i);
68 }
69 ValueRef::Double(r) => {
70 self.push_real(r);
71 }
72 ValueRef::Text(s) => {
73 let s = std::str::from_utf8(s)?;
74 self.push_string_literal(s);
75 }
76 _ => {
77 return Err(Error::DuckDBFailure(
78 ffi::Error::new(ffi::DuckDBError),
79 Some(format!("Unsupported value \"{value:?}\"")),
80 ));
81 }
82 };
83 Ok(())
84 }
85
86 pub fn push_string_literal(&mut self, s: &str) {
87 self.wrap_and_escape(s, '\'');
88 }
89
90 pub fn push_int(&mut self, i: i64) {
91 self.buf.push_str(&i.to_string());
92 }
93
94 pub fn push_real(&mut self, f: f64) {
95 self.buf.push_str(&f.to_string());
96 }
97
98 pub fn push_space(&mut self) {
99 self.buf.push(' ');
100 }
101
102 pub fn push_dot(&mut self) {
103 self.buf.push('.');
104 }
105
106 pub fn push_equal_sign(&mut self) {
107 self.buf.push('=');
108 }
109
110 pub fn open_brace(&mut self) {
111 self.buf.push('(');
112 }
113
114 pub fn close_brace(&mut self) {
115 self.buf.push(')');
116 }
117
118 pub fn as_str(&self) -> &str {
119 &self.buf
120 }
121
122 fn wrap_and_escape(&mut self, s: &str, quote: char) {
123 self.buf.push(quote);
124 let chars = s.chars();
125 for ch in chars {
126 if ch == quote {
128 self.buf.push(ch);
129 }
130 self.buf.push(ch)
131 }
132 self.buf.push(quote);
133 }
134}
135
136impl Deref for Sql {
137 type Target = str;
138
139 fn deref(&self) -> &str {
140 self.as_str()
141 }
142}
143
144impl Connection {
145 pub fn pragma_query_value<T, F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, f: F) -> Result<T>
153 where
154 F: FnOnce(&Row<'_>) -> Result<T>,
155 {
156 let mut query = Sql::new();
157 query.push_pragma(schema_name, pragma_name)?;
158 self.query_row(&query, [], f)
159 }
160
161 pub fn pragma_query<F>(&self, schema_name: Option<DatabaseName<'_>>, pragma_name: &str, mut f: F) -> Result<()>
166 where
167 F: FnMut(&Row<'_>) -> Result<()>,
168 {
169 let mut query = Sql::new();
170 query.push_pragma(schema_name, pragma_name)?;
171 let mut stmt = self.prepare(&query)?;
172 let mut rows = stmt.query([])?;
173 while let Some(result_row) = rows.next()? {
174 f(result_row)?;
175 }
176 Ok(())
177 }
178
179 pub fn pragma<F>(
189 &self,
190 schema_name: Option<DatabaseName<'_>>,
191 pragma_name: &str,
192 pragma_value: &dyn ToSql,
193 mut f: F,
194 ) -> Result<()>
195 where
196 F: FnMut(&Row<'_>) -> Result<()>,
197 {
198 let mut sql = Sql::new();
199 sql.push_pragma(schema_name, pragma_name)?;
200 sql.open_brace();
204 sql.push_value(pragma_value)?;
205 sql.close_brace();
206 let mut stmt = self.prepare(&sql)?;
207 let mut rows = stmt.query([])?;
208 while let Some(result_row) = rows.next()? {
209 let row = result_row;
210 f(row)?;
211 }
212 Ok(())
213 }
214
215 pub fn pragma_update(
220 &self,
221 schema_name: Option<DatabaseName<'_>>,
222 pragma_name: &str,
223 pragma_value: &dyn ToSql,
224 ) -> Result<()> {
225 let mut sql = Sql::new();
226 sql.push_pragma(schema_name, pragma_name)?;
227 sql.push_equal_sign();
231 sql.push_value(pragma_value)?;
232 self.execute_batch(&sql)
233 }
234
235 pub fn pragma_update_and_check<F, T>(
239 &self,
240 schema_name: Option<DatabaseName<'_>>,
241 pragma_name: &str,
242 pragma_value: &dyn ToSql,
243 f: F,
244 ) -> Result<T>
245 where
246 F: FnOnce(&Row<'_>) -> Result<T>,
247 {
248 let mut sql = Sql::new();
249 sql.push_pragma(schema_name, pragma_name)?;
250 sql.push_equal_sign();
254 sql.push_value(pragma_value)?;
255 self.query_row(&sql, [], f)
256 }
257}
258
259fn is_identifier(s: &str) -> bool {
260 let chars = s.char_indices();
261 for (i, ch) in chars {
262 if i == 0 {
263 if !is_identifier_start(ch) {
264 return false;
265 }
266 } else if !is_identifier_continue(ch) {
267 return false;
268 }
269 }
270 true
271}
272
/// Reports whether `c` may begin an unquoted identifier.
///
/// Accepted characters are ASCII letters, `_`, and any non-ASCII character.
fn is_identifier_start(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_') || !c.is_ascii()
}
276
/// Reports whether `c` may appear after the first character of an unquoted
/// identifier.
///
/// Accepted characters are ASCII letters and digits, `_`, `$`, and any non-ASCII
/// character.
fn is_identifier_continue(c: char) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '$') || !c.is_ascii()
}
280
#[cfg(test)]
mod test {
    use super::Sql;
    use crate::{pragma, Connection, DatabaseName, Result};

    #[test]
    fn pragma_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: String = conn.pragma_query_value(None, "version", |row| row.get(0))?;
        assert!(!version.is_empty());
        Ok(())
    }

    #[test]
    #[ignore = "not supported"]
    fn pragma_query_with_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut version = String::new();
        conn.pragma_query(Some(DatabaseName::Main), "version", |row| {
            version = row.get(0)?;
            Ok(())
        })?;
        assert!(!version.is_empty());
        Ok(())
    }

    #[test]
    fn pragma() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut column_names = Vec::new();
        conn.pragma(None, "table_info", &"sqlite_master", |row| {
            let name: String = row.get(1)?;
            column_names.push(name);
            Ok(())
        })?;
        assert_eq!(5, column_names.len());
        Ok(())
    }

    #[test]
    fn pragma_update() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "explain_output", &"PHYSICAL_ONLY")
    }

    #[test]
    #[ignore = "don't support query pragma"]
    fn test_pragma_update_and_check() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let explain_output: String =
            conn.pragma_update_and_check(None, "explain_output", &"OPTIMIZED_ONLY", |row| row.get(0))?;
        assert_eq!("OPTIMIZED_ONLY", &explain_output);
        Ok(())
    }

    #[test]
    fn is_identifier() {
        // Plain identifiers are accepted; anything needing quoting is not.
        assert!(pragma::is_identifier("full"));
        assert!(pragma::is_identifier("r2d2"));
        assert!(!pragma::is_identifier("sp ce"));
        assert!(!pragma::is_identifier("semi;colon"));
    }

    #[test]
    fn double_quote() {
        let mut buffer = Sql::new();
        buffer.push_schema_name(DatabaseName::Attached(r#"schema";--"#));
        assert_eq!(r#""schema"";--""#, buffer.as_str());
    }

    #[test]
    fn wrap_and_escape() {
        let mut buffer = Sql::new();
        buffer.push_string_literal("value'; --");
        assert_eq!("'value''; --'", buffer.as_str());
    }

    #[test]
    #[ignore]
    fn test_locking_mode() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let outcome = conn.pragma_update(None, "locking_mode", &"exclusive");
        if cfg!(feature = "extra_check") {
            outcome.unwrap_err();
        } else {
            outcome?;
        }
        Ok(())
    }
}