1use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
20use sqlparser::dialect::SQLiteDialect;
21use sqlparser::keywords::Keyword;
22use sqlparser::tokenizer::{Token, Tokenizer};
23
24use crate::error::{Result, SQLRiteError};
25use crate::sql::CommandOutput;
26use crate::sql::db::database::Database;
27
/// The right-hand side of a `PRAGMA name = value` / `PRAGMA name(value)` statement.
///
/// Values keep the text the tokenizer produced; interpreting them is left to each
/// pragma's handler.
#[derive(Debug, Clone, PartialEq)]
pub enum PragmaValue {
    /// A numeric literal's source text. A leading unary minus is folded in (`"-0.1"`).
    Number(String),
    /// A bare or quoted identifier such as `OFF`. Under the SQLite dialect `"NONE"` also
    /// lands here, because double quotes delimit identifiers.
    Identifier(String),
    /// A string literal such as `'OFF'`, stored without its quotes.
    String(String),
}
44
45#[derive(Debug, Clone, PartialEq)]
48pub struct PragmaStatement {
49 pub name: String,
50 pub value: Option<PragmaValue>,
51}
52
53pub fn try_parse_pragma(sql: &str) -> Result<Option<PragmaStatement>> {
59 let dialect = SQLiteDialect {};
60 let tokens = Tokenizer::new(&dialect, sql)
61 .tokenize()
62 .map_err(|e| SQLRiteError::General(format!("PRAGMA tokenize error: {e}")))?;
63
64 let mut iter = tokens
65 .into_iter()
66 .filter(|t| !matches!(t, Token::Whitespace(_)))
67 .peekable();
68
69 match iter.peek() {
72 Some(Token::Word(w)) if w.keyword == Keyword::PRAGMA => {
73 iter.next();
74 }
75 _ => return Ok(None),
76 }
77
78 let name = match iter.next() {
79 Some(Token::Word(w)) => w.value,
80 Some(other) => {
81 return Err(SQLRiteError::General(format!(
82 "PRAGMA: expected pragma name, got {other:?}"
83 )));
84 }
85 None => {
86 return Err(SQLRiteError::General(
87 "PRAGMA: missing pragma name".to_string(),
88 ));
89 }
90 };
91
92 let value = match iter.peek() {
93 None | Some(Token::SemiColon) => None,
94 Some(Token::Eq) => {
95 iter.next();
96 Some(read_pragma_value(&mut iter)?)
97 }
98 Some(Token::LParen) => {
99 iter.next();
100 let v = read_pragma_value(&mut iter)?;
101 match iter.next() {
102 Some(Token::RParen) => {}
103 Some(other) => {
104 return Err(SQLRiteError::General(format!(
105 "PRAGMA: expected ')' to close parenthesised value, got {other:?}"
106 )));
107 }
108 None => {
109 return Err(SQLRiteError::General(
110 "PRAGMA: expected ')' to close parenthesised value".to_string(),
111 ));
112 }
113 }
114 Some(v)
115 }
116 Some(other) => {
117 return Err(SQLRiteError::General(format!(
118 "PRAGMA: expected '=', '(', ';' or end of statement after name, got {other:?}"
119 )));
120 }
121 };
122
123 if matches!(iter.peek(), Some(Token::SemiColon)) {
127 iter.next();
128 }
129 if let Some(extra) = iter.next() {
130 return Err(SQLRiteError::General(format!(
131 "PRAGMA: unexpected trailing content {extra:?}"
132 )));
133 }
134
135 Ok(Some(PragmaStatement { name, value }))
136}
137
138fn read_pragma_value<I>(iter: &mut std::iter::Peekable<I>) -> Result<PragmaValue>
139where
140 I: Iterator<Item = Token>,
141{
142 let mut neg = false;
147 let first = iter.next().ok_or_else(|| {
148 SQLRiteError::General("PRAGMA: missing value after '=' or '('".to_string())
149 })?;
150
151 let tok = if matches!(first, Token::Minus) {
152 neg = true;
153 iter.next()
154 .ok_or_else(|| SQLRiteError::General("PRAGMA: missing value after '-'".to_string()))?
155 } else {
156 first
157 };
158
159 Ok(match tok {
160 Token::Number(s, _) => {
161 if neg {
162 PragmaValue::Number(format!("-{s}"))
163 } else {
164 PragmaValue::Number(s)
165 }
166 }
167 Token::SingleQuotedString(s) | Token::DoubleQuotedString(s) => {
168 if neg {
169 return Err(SQLRiteError::General(
170 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
171 ));
172 }
173 PragmaValue::String(s)
174 }
175 Token::Word(w) => {
176 if neg {
177 return Err(SQLRiteError::General(
178 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
179 ));
180 }
181 PragmaValue::Identifier(w.value)
182 }
183 other => {
184 return Err(SQLRiteError::General(format!(
185 "PRAGMA: unsupported value token {other:?}"
186 )));
187 }
188 })
189}
190
191pub fn execute_pragma(stmt: PragmaStatement, db: &mut Database) -> Result<CommandOutput> {
194 match stmt.name.to_ascii_lowercase().as_str() {
195 "auto_vacuum" => pragma_auto_vacuum(stmt.value, db),
196 other => Err(SQLRiteError::NotImplemented(format!(
197 "PRAGMA '{other}' is not supported"
198 ))),
199 }
200}
201
202fn pragma_auto_vacuum(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
206 match value {
207 None => {
208 let mut t = PrintTable::new();
216 t.add_row(PrintRow::new(vec![PrintCell::new("auto_vacuum")]));
217 let cell_value = match db.auto_vacuum_threshold() {
218 Some(v) => format!("{v}"),
219 None => "OFF".to_string(),
220 };
221 t.add_row(PrintRow::new(vec![PrintCell::new(&cell_value)]));
222 Ok(CommandOutput {
223 status: "PRAGMA auto_vacuum executed. 1 row returned.".to_string(),
224 rendered: Some(t.to_string()),
225 })
226 }
227 Some(v) => {
228 let new_threshold = parse_auto_vacuum_target(&v)?;
229 db.set_auto_vacuum_threshold(new_threshold)?;
230 Ok(CommandOutput {
231 status: "PRAGMA auto_vacuum executed.".to_string(),
232 rendered: None,
233 })
234 }
235 }
236}
237
238fn parse_auto_vacuum_target(value: &PragmaValue) -> Result<Option<f32>> {
243 match value {
244 PragmaValue::Identifier(s) | PragmaValue::String(s) => {
245 match s.to_ascii_lowercase().as_str() {
246 "off" | "none" => Ok(None),
247 _ => Err(SQLRiteError::General(format!(
248 "PRAGMA auto_vacuum: expected a number in 0.0..=1.0 or OFF/NONE, got '{s}'"
249 ))),
250 }
251 }
252 PragmaValue::Number(s) => {
253 let f: f32 = s.parse().map_err(|_| {
254 SQLRiteError::General(format!("PRAGMA auto_vacuum: '{s}' is not a valid number"))
255 })?;
256 Ok(Some(f))
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql`, panicking unless it is a well-formed PRAGMA.
    fn parse_ok(sql: &str) -> PragmaStatement {
        try_parse_pragma(sql)
            .expect("PRAGMA should parse without error")
            .expect("input should be recognised as a PRAGMA")
    }

    fn number(raw: &str) -> PragmaValue {
        PragmaValue::Number(raw.to_string())
    }

    fn ident(raw: &str) -> PragmaValue {
        PragmaValue::Identifier(raw.to_string())
    }

    fn string(raw: &str) -> PragmaValue {
        PragmaValue::String(raw.to_string())
    }

    /// Runs the PRAGMA `name` with `value` against `db` through `execute_pragma`.
    fn run(db: &mut Database, name: &str, value: Option<PragmaValue>) -> Result<CommandOutput> {
        execute_pragma(
            PragmaStatement {
                name: name.to_string(),
                value,
            },
            db,
        )
    }

    fn fresh_db() -> Database {
        Database::new("t".to_string())
    }

    #[test]
    fn try_parse_pragma_returns_none_for_non_pragma() {
        let inputs = [
            "SELECT 1;",
            "CREATE TABLE t (id INTEGER);",
            // Inputs without any real tokens are "not a PRAGMA", not an error.
            "",
            " \n\t ",
            "-- hello\n",
        ];
        for sql in inputs {
            assert!(try_parse_pragma(sql).unwrap().is_none(), "{sql:?}");
        }
    }

    #[test]
    fn try_parse_pragma_read_form() {
        // Surrounding whitespace and a missing ';' are both accepted.
        for sql in ["PRAGMA auto_vacuum;", " PRAGMA auto_vacuum "] {
            let stmt = parse_ok(sql);
            assert_eq!(stmt.name, "auto_vacuum");
            assert_eq!(stmt.value, None);
        }

        // The keyword itself is case-insensitive.
        assert_eq!(parse_ok("pragma auto_vacuum;").name, "auto_vacuum");
    }

    #[test]
    fn try_parse_pragma_eq_number() {
        let stmt = parse_ok("PRAGMA auto_vacuum = 0.5;");
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, Some(number("0.5")));

        let cases = [
            ("PRAGMA auto_vacuum = 0;", "0"),
            // '-' and the digits are separate tokens; the parser re-joins them.
            ("PRAGMA auto_vacuum = -0.1;", "-0.1"),
        ];
        for (sql, expected) in cases {
            assert_eq!(parse_ok(sql).value, Some(number(expected)));
        }
    }

    #[test]
    fn try_parse_pragma_eq_identifier() {
        let cases = [
            ("PRAGMA auto_vacuum = OFF;", "OFF"),
            ("PRAGMA auto_vacuum = none;", "none"),
        ];
        for (sql, expected) in cases {
            assert_eq!(parse_ok(sql).value, Some(ident(expected)));
        }
    }

    #[test]
    fn try_parse_pragma_eq_string() {
        assert_eq!(
            parse_ok("PRAGMA auto_vacuum = 'OFF';").value,
            Some(string("OFF"))
        );

        // Under the SQLite dialect "..." quotes an identifier, not a string literal.
        assert_eq!(
            parse_ok("PRAGMA auto_vacuum = \"NONE\";").value,
            Some(ident("NONE"))
        );
    }

    #[test]
    fn try_parse_pragma_paren_form() {
        assert_eq!(
            parse_ok("PRAGMA auto_vacuum(0.5);").value,
            Some(number("0.5"))
        );
        assert_eq!(
            parse_ok("PRAGMA auto_vacuum (OFF);").value,
            Some(ident("OFF"))
        );
    }

    #[test]
    fn try_parse_pragma_rejects_malformed() {
        let inputs = [
            "PRAGMA;",
            "PRAGMA = 0.5;",
            "PRAGMA auto_vacuum =;",
            "PRAGMA auto_vacuum (0.5;",
            // A second statement after the PRAGMA is not silently dropped.
            "PRAGMA auto_vacuum; SELECT 1;",
            // Unary minus only applies to numbers.
            "PRAGMA auto_vacuum = -'OFF';",
        ];
        for sql in inputs {
            assert!(try_parse_pragma(sql).is_err(), "{sql:?} should be rejected");
        }
    }

    #[test]
    fn parse_auto_vacuum_target_disables_on_off_or_none() {
        for raw in ["OFF", "off", "Off", "NONE", "none"] {
            for value in [ident(raw), string(raw)] {
                assert_eq!(parse_auto_vacuum_target(&value).unwrap(), None);
            }
        }
    }

    #[test]
    fn parse_auto_vacuum_target_passes_numbers_through() {
        // Range checking is not this function's job, so 1.5 passes through unchanged.
        let cases = [("0.5", 0.5_f32), ("0", 0.0_f32), ("1.5", 1.5_f32)];
        for (raw, expected) in cases {
            assert_eq!(
                parse_auto_vacuum_target(&number(raw)).unwrap(),
                Some(expected)
            );
        }
    }

    #[test]
    fn parse_auto_vacuum_target_rejects_unknown_strings() {
        let err = parse_auto_vacuum_target(&ident("WAL")).unwrap_err();
        assert!(format!("{err}").contains("OFF/NONE"));
    }

    #[test]
    fn execute_pragma_unknown_returns_not_implemented() {
        let mut db = fresh_db();
        let err = run(&mut db, "journal_mode", None).unwrap_err();
        assert!(matches!(err, SQLRiteError::NotImplemented(_)));
    }

    #[test]
    fn execute_pragma_auto_vacuum_set_and_read() {
        let mut db = fresh_db();

        // Write a numeric threshold: nothing rendered, value stored.
        let out = run(&mut db, "auto_vacuum", Some(number("0.5"))).unwrap();
        assert!(out.rendered.is_none());
        assert_eq!(db.auto_vacuum_threshold(), Some(0.5));

        // Read it back.
        let rendered = run(&mut db, "auto_vacuum", None)
            .unwrap()
            .rendered
            .expect("read form must render rows");
        assert!(rendered.contains("auto_vacuum"));
        assert!(rendered.contains("0.5"));

        // Disable, then confirm the read form reports OFF.
        run(&mut db, "auto_vacuum", Some(ident("OFF"))).unwrap();
        assert_eq!(db.auto_vacuum_threshold(), None);

        let rendered = run(&mut db, "auto_vacuum", None)
            .unwrap()
            .rendered
            .unwrap();
        assert!(rendered.contains("OFF"));
    }

    #[test]
    fn execute_pragma_auto_vacuum_rejects_out_of_range() {
        let mut db = fresh_db();
        let err = run(&mut db, "auto_vacuum", Some(number("1.5"))).unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));

        // A rejected write leaves the threshold at the value it had before (0.25 here).
        assert_eq!(db.auto_vacuum_threshold(), Some(0.25));
    }

    #[test]
    fn execute_pragma_auto_vacuum_rejects_negative() {
        let mut db = fresh_db();
        let err = run(&mut db, "auto_vacuum", Some(number("-0.1"))).unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));
    }
}