1use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
20use sqlparser::dialect::SQLiteDialect;
21use sqlparser::keywords::Keyword;
22use sqlparser::tokenizer::{Token, Tokenizer};
23
24use crate::error::{Result, SQLRiteError};
25use crate::mvcc::JournalMode;
26use crate::sql::CommandOutput;
27use crate::sql::db::database::Database;
28
/// The right-hand side of a `PRAGMA name = value` or `PRAGMA name(value)` statement.
///
/// The parser only records the lexical category of the value; each pragma handler
/// decides how to interpret it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PragmaValue {
    /// Numeric literal text exactly as tokenized, with a leading `-` when the value
    /// was written with unary minus (e.g. `"0.5"`, `"-0.1"`).
    Number(String),
    /// A bare word such as `OFF`, or a double-quoted word (the SQLite dialect
    /// tokenizes `"..."` as a quoted identifier).
    Identifier(String),
    /// A quoted string literal such as `'wal'`, stored without its quotes.
    String(String),
}

/// A parsed `PRAGMA` statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PragmaStatement {
    /// Pragma name as written; `execute_pragma` matches it case-insensitively.
    pub name: String,
    /// `None` for the read form (`PRAGMA name`), `Some` for the set forms.
    pub value: Option<PragmaValue>,
}
53
54pub fn try_parse_pragma(sql: &str) -> Result<Option<PragmaStatement>> {
60 let dialect = SQLiteDialect {};
61 let tokens = Tokenizer::new(&dialect, sql)
62 .tokenize()
63 .map_err(|e| SQLRiteError::General(format!("PRAGMA tokenize error: {e}")))?;
64
65 let mut iter = tokens
66 .into_iter()
67 .filter(|t| !matches!(t, Token::Whitespace(_)))
68 .peekable();
69
70 match iter.peek() {
73 Some(Token::Word(w)) if w.keyword == Keyword::PRAGMA => {
74 iter.next();
75 }
76 _ => return Ok(None),
77 }
78
79 let name = match iter.next() {
80 Some(Token::Word(w)) => w.value,
81 Some(other) => {
82 return Err(SQLRiteError::General(format!(
83 "PRAGMA: expected pragma name, got {other:?}"
84 )));
85 }
86 None => {
87 return Err(SQLRiteError::General(
88 "PRAGMA: missing pragma name".to_string(),
89 ));
90 }
91 };
92
93 let value = match iter.peek() {
94 None | Some(Token::SemiColon) => None,
95 Some(Token::Eq) => {
96 iter.next();
97 Some(read_pragma_value(&mut iter)?)
98 }
99 Some(Token::LParen) => {
100 iter.next();
101 let v = read_pragma_value(&mut iter)?;
102 match iter.next() {
103 Some(Token::RParen) => {}
104 Some(other) => {
105 return Err(SQLRiteError::General(format!(
106 "PRAGMA: expected ')' to close parenthesised value, got {other:?}"
107 )));
108 }
109 None => {
110 return Err(SQLRiteError::General(
111 "PRAGMA: expected ')' to close parenthesised value".to_string(),
112 ));
113 }
114 }
115 Some(v)
116 }
117 Some(other) => {
118 return Err(SQLRiteError::General(format!(
119 "PRAGMA: expected '=', '(', ';' or end of statement after name, got {other:?}"
120 )));
121 }
122 };
123
124 if matches!(iter.peek(), Some(Token::SemiColon)) {
128 iter.next();
129 }
130 if let Some(extra) = iter.next() {
131 return Err(SQLRiteError::General(format!(
132 "PRAGMA: unexpected trailing content {extra:?}"
133 )));
134 }
135
136 Ok(Some(PragmaStatement { name, value }))
137}
138
139fn read_pragma_value<I>(iter: &mut std::iter::Peekable<I>) -> Result<PragmaValue>
140where
141 I: Iterator<Item = Token>,
142{
143 let mut neg = false;
148 let first = iter.next().ok_or_else(|| {
149 SQLRiteError::General("PRAGMA: missing value after '=' or '('".to_string())
150 })?;
151
152 let tok = if matches!(first, Token::Minus) {
153 neg = true;
154 iter.next()
155 .ok_or_else(|| SQLRiteError::General("PRAGMA: missing value after '-'".to_string()))?
156 } else {
157 first
158 };
159
160 Ok(match tok {
161 Token::Number(s, _) => {
162 if neg {
163 PragmaValue::Number(format!("-{s}"))
164 } else {
165 PragmaValue::Number(s)
166 }
167 }
168 Token::SingleQuotedString(s) | Token::DoubleQuotedString(s) => {
169 if neg {
170 return Err(SQLRiteError::General(
171 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
172 ));
173 }
174 PragmaValue::String(s)
175 }
176 Token::Word(w) => {
177 if neg {
178 return Err(SQLRiteError::General(
179 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
180 ));
181 }
182 PragmaValue::Identifier(w.value)
183 }
184 other => {
185 return Err(SQLRiteError::General(format!(
186 "PRAGMA: unsupported value token {other:?}"
187 )));
188 }
189 })
190}
191
192pub fn execute_pragma(stmt: PragmaStatement, db: &mut Database) -> Result<CommandOutput> {
195 match stmt.name.to_ascii_lowercase().as_str() {
196 "auto_vacuum" => pragma_auto_vacuum(stmt.value, db),
197 "journal_mode" => pragma_journal_mode(stmt.value, db),
198 other => Err(SQLRiteError::NotImplemented(format!(
199 "PRAGMA '{other}' is not supported"
200 ))),
201 }
202}
203
204fn pragma_journal_mode(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
210 match value {
211 None => render_journal_mode(db.journal_mode()),
212 Some(v) => {
213 let target = parse_journal_mode_target(&v)?;
214 db.set_journal_mode(target)?;
215 render_journal_mode(db.journal_mode())
218 }
219 }
220}
221
222fn render_journal_mode(mode: JournalMode) -> Result<CommandOutput> {
223 let mut t = PrintTable::new();
224 t.add_row(PrintRow::new(vec![PrintCell::new("journal_mode")]));
225 t.add_row(PrintRow::new(vec![PrintCell::new(mode.as_str())]));
226 Ok(CommandOutput {
227 status: "PRAGMA journal_mode executed. 1 row returned.".to_string(),
228 rendered: Some(t.to_string()),
229 })
230}
231
232fn parse_journal_mode_target(value: &PragmaValue) -> Result<JournalMode> {
233 let s = match value {
234 PragmaValue::Identifier(s) | PragmaValue::String(s) => s.as_str(),
235 PragmaValue::Number(s) => {
236 return Err(SQLRiteError::General(format!(
237 "PRAGMA journal_mode: expected 'wal' or 'mvcc', got numeric '{s}'"
238 )));
239 }
240 };
241 JournalMode::from_str_lossless(s).ok_or_else(|| {
242 SQLRiteError::General(format!(
243 "PRAGMA journal_mode: unknown mode '{s}' (supported: 'wal', 'mvcc')"
244 ))
245 })
246}
247
248fn pragma_auto_vacuum(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
252 match value {
253 None => {
254 let mut t = PrintTable::new();
262 t.add_row(PrintRow::new(vec![PrintCell::new("auto_vacuum")]));
263 let cell_value = match db.auto_vacuum_threshold() {
264 Some(v) => format!("{v}"),
265 None => "OFF".to_string(),
266 };
267 t.add_row(PrintRow::new(vec![PrintCell::new(&cell_value)]));
268 Ok(CommandOutput {
269 status: "PRAGMA auto_vacuum executed. 1 row returned.".to_string(),
270 rendered: Some(t.to_string()),
271 })
272 }
273 Some(v) => {
274 let new_threshold = parse_auto_vacuum_target(&v)?;
275 db.set_auto_vacuum_threshold(new_threshold)?;
276 Ok(CommandOutput {
277 status: "PRAGMA auto_vacuum executed.".to_string(),
278 rendered: None,
279 })
280 }
281 }
282}
283
284fn parse_auto_vacuum_target(value: &PragmaValue) -> Result<Option<f32>> {
289 match value {
290 PragmaValue::Identifier(s) | PragmaValue::String(s) => {
291 match s.to_ascii_lowercase().as_str() {
292 "off" | "none" => Ok(None),
293 _ => Err(SQLRiteError::General(format!(
294 "PRAGMA auto_vacuum: expected a number in 0.0..=1.0 or OFF/NONE, got '{s}'"
295 ))),
296 }
297 }
298 PragmaValue::Number(s) => {
299 let f: f32 = s.parse().map_err(|_| {
300 SQLRiteError::General(format!("PRAGMA auto_vacuum: '{s}' is not a valid number"))
301 })?;
302 Ok(Some(f))
303 }
304 }
305}
306
// Unit tests covering three areas:
// - PRAGMA parsing (`try_parse_pragma`);
// - auto_vacuum value interpretation (`parse_auto_vacuum_target`);
// - dispatch and execution against a `Database` (`execute_pragma`).
#[cfg(test)]
mod tests {
    use super::*;

    // Anything that does not begin with the PRAGMA keyword must yield Ok(None), so
    // the caller falls through to the regular SQL parser. This includes empty,
    // whitespace-only and comment-only input.
    #[test]
    fn try_parse_pragma_returns_none_for_non_pragma() {
        assert!(try_parse_pragma("SELECT 1;").unwrap().is_none());
        assert!(
            try_parse_pragma("CREATE TABLE t (id INTEGER);")
                .unwrap()
                .is_none()
        );
        assert!(try_parse_pragma("").unwrap().is_none());
        // Whitespace and comments are filtered out before the keyword check.
        assert!(try_parse_pragma(" \n\t ").unwrap().is_none());
        assert!(try_parse_pragma("-- hello\n").unwrap().is_none());
    }

    // Read form (no value).
    #[test]
    fn try_parse_pragma_read_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);

        // Surrounding whitespace and a missing trailing `;` are accepted.
        let stmt = try_parse_pragma(" PRAGMA auto_vacuum ").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);

        // The PRAGMA keyword is matched case-insensitively.
        let stmt = try_parse_pragma("pragma auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
    }

    // `= <number>` keeps the literal text of the number.
    #[test]
    fn try_parse_pragma_eq_number() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0.5;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0".to_string())));

        // Unary minus is folded into the number text. The parser does not
        // range-check; a negative threshold is rejected at execution time.
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = -0.1;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("-0.1".to_string())));
    }

    // A bare word is kept verbatim (case preserved) as an Identifier.
    #[test]
    fn try_parse_pragma_eq_identifier() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = OFF;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum = none;")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("none".to_string()))
        );
    }

    // Quoted values.
    #[test]
    fn try_parse_pragma_eq_string() {
        // Single quotes produce a String without the quotes.
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 'OFF';")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::String("OFF".to_string())));

        // The SQLite dialect tokenizes double-quoted text as a quoted identifier,
        // so it surfaces as Identifier rather than String.
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = \"NONE\";")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("NONE".to_string()))
        );
    }

    // `name(value)` is equivalent to `name = value`, with or without a space
    // before the parenthesis.
    #[test]
    fn try_parse_pragma_paren_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum(0.5);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum (OFF);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));
    }

    // Input that starts with PRAGMA but is malformed is an error, not Ok(None).
    #[test]
    fn try_parse_pragma_rejects_malformed() {
        // Missing pragma name.
        assert!(try_parse_pragma("PRAGMA;").is_err());
        assert!(try_parse_pragma("PRAGMA = 0.5;").is_err());
        // Missing value after `=`.
        assert!(try_parse_pragma("PRAGMA auto_vacuum =;").is_err());
        // Unclosed parenthesis.
        assert!(try_parse_pragma("PRAGMA auto_vacuum (0.5;").is_err());
        // Only one statement per call: anything after the `;` is rejected.
        assert!(try_parse_pragma("PRAGMA auto_vacuum; SELECT 1;").is_err());
        // Unary minus is only allowed in front of a number.
        assert!(try_parse_pragma("PRAGMA auto_vacuum = -'OFF';").is_err());
    }

    // OFF/NONE disable auto-vacuum in any ASCII case, whether bare or quoted.
    #[test]
    fn parse_auto_vacuum_target_disables_on_off_or_none() {
        for raw in ["OFF", "off", "Off", "NONE", "none"] {
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::Identifier(raw.to_string())).unwrap(),
                None
            );
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::String(raw.to_string())).unwrap(),
                None
            );
        }
    }

    // Numbers are parsed but not range-checked here. 1.5 passes through; the
    // out-of-range rejection is exercised by the execute_pragma tests.
    #[test]
    fn parse_auto_vacuum_target_passes_numbers_through() {
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0.5".to_string())).unwrap(),
            Some(0.5_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0".to_string())).unwrap(),
            Some(0.0_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("1.5".to_string())).unwrap(),
            Some(1.5_f32)
        );
    }

    // Words other than OFF/NONE are rejected, and the message lists the accepted words.
    #[test]
    fn parse_auto_vacuum_target_rejects_unknown_strings() {
        let err =
            parse_auto_vacuum_target(&PragmaValue::Identifier("WAL".to_string())).unwrap_err();
        assert!(format!("{err}").contains("OFF/NONE"));
    }

    // Pragma names without a handler in execute_pragma map to NotImplemented.
    #[test]
    fn execute_pragma_unknown_returns_not_implemented() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "synchronous".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap_err();
        assert!(matches!(err, SQLRiteError::NotImplemented(_)));
    }

    // Full round trip: set, read back, disable, then read back again.
    #[test]
    fn execute_pragma_auto_vacuum_set_and_read() {
        let mut db = Database::new("t".to_string());

        // The set form updates the database and renders nothing.
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("0.5".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert!(out.rendered.is_none());
        assert_eq!(db.auto_vacuum_threshold(), Some(0.5));

        // The read form renders a table containing the header and the current value.
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.expect("read form must render rows");
        assert!(rendered.contains("auto_vacuum"));
        assert!(rendered.contains("0.5"));

        // OFF clears the threshold.
        execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Identifier("OFF".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert_eq!(db.auto_vacuum_threshold(), None);

        // With no threshold, the read form shows OFF.
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.unwrap();
        assert!(rendered.contains("OFF"));
    }

    // 1.5 passes parse_auto_vacuum_target, so this rejection comes from the
    // Database setter; its message mentions auto_vacuum_threshold.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_out_of_range() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("1.5".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));

        // A rejected set leaves the fresh database's threshold (0.25) untouched.
        assert_eq!(db.auto_vacuum_threshold(), Some(0.25));
    }

    // Negative thresholds are rejected the same way as values above 1.0.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_negative() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("-0.1".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));
    }
}