Skip to main content

systemprompt_database/admin/
admin_sql.rs

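//! Parser/validator for admin-supplied SQL strings.
//!
//! Two parse modes are exposed. Both strip comments, trim surrounding whitespace,
//! drop one trailing `;`, and reject input that still contains a `;` (i.e. more
//! than one statement):
//! - [`AdminSql::parse_readonly`] — additionally requires the statement to begin
//!   with `SELECT`/`WITH`/`EXPLAIN`/`SHOW`/`TABLE`/`VALUES` and to contain no
//!   forbidden keywords.
//! - [`AdminSql::parse_unrestricted`] — no further restrictions on the statement.
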
8use thiserror::Error;
9
/// Default maximum number of rows for read-only admin queries.
///
/// Not applied in this module; presumably enforced by whichever caller executes
/// an [`AdminSql`] in read-only mode (TODO: confirm).
pub const DEFAULT_READONLY_ROW_LIMIT: usize = 1000;

/// Leading keywords accepted by `AdminSql::parse_readonly`, compared as whole
/// words against the lower-cased statement.
const READONLY_PREFIXES: &[&str] = &["select", "with", "explain", "show", "table", "values"];

/// Keywords that make `AdminSql::parse_readonly` reject a statement.
///
/// Entries are lower-case and wrapped in single spaces so they can be matched as
/// whole words; strip the padding (`str::trim`) when comparing against
/// individual tokens.
const FORBIDDEN_KEYWORDS: &[&str] = &[
    " drop ",
    " delete ",
    " insert ",
    " update ",
    " alter ",
    " create ",
    " truncate ",
    " grant ",
    " revoke ",
    " copy ",
    " vacuum ",
    " call ",
    " lock ",
    " set ",
    " reset ",
    " rename ",
    // `EXPLAIN ANALYZE MERGE ...` actually executes the merge, and MERGE is not
    // covered by the INSERT/UPDATE/DELETE entries.
    " merge ",
    // `SELECT ... INTO new_table` creates and fills a table in PostgreSQL (and
    // `SELECT ... INTO OUTFILE` writes a file in MySQL).
    " into ",
];
32
33#[derive(Debug, Clone, Copy, Error)]
34pub enum AdminSqlError {
35    #[error("SQL query is empty")]
36    Empty,
37    #[error("SQL query contains multiple statements; only one is allowed")]
38    MultipleStatements,
39    #[error("SQL query must begin with SELECT, WITH, EXPLAIN, SHOW, TABLE, or VALUES")]
40    NotReadOnly,
41    #[error("SQL query contains forbidden keyword for read-only mode")]
42    ForbiddenKeyword,
43}
44
/// An admin-supplied SQL statement accepted by one of the `AdminSql::parse_*`
/// constructors.
///
/// As produced by those constructors, the wrapped text has had comments
/// stripped, surrounding whitespace trimmed, and at most one trailing `;`
/// removed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdminSql(String);
47
48impl AdminSql {
49    pub fn parse_readonly(raw: &str) -> Result<Self, AdminSqlError> {
50        let stripped = strip_comments(raw);
51        let trimmed = stripped.trim();
52        if trimmed.is_empty() {
53            return Err(AdminSqlError::Empty);
54        }
55
56        let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
57        if without_terminator.contains(';') {
58            return Err(AdminSqlError::MultipleStatements);
59        }
60
61        let lower = without_terminator.to_lowercase();
62        if !READONLY_PREFIXES
63            .iter()
64            .any(|p| starts_with_word(&lower, p))
65        {
66            return Err(AdminSqlError::NotReadOnly);
67        }
68
69        let padded = format!(" {lower} ");
70        if FORBIDDEN_KEYWORDS.iter().any(|kw| padded.contains(kw)) {
71            return Err(AdminSqlError::ForbiddenKeyword);
72        }
73
74        Ok(Self(without_terminator.to_string()))
75    }
76
77    pub fn parse_unrestricted(raw: &str) -> Result<Self, AdminSqlError> {
78        let stripped = strip_comments(raw);
79        let trimmed = stripped.trim();
80        if trimmed.is_empty() {
81            return Err(AdminSqlError::Empty);
82        }
83
84        let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
85        if without_terminator.contains(';') {
86            return Err(AdminSqlError::MultipleStatements);
87        }
88
89        Ok(Self(without_terminator.to_string()))
90    }
91
92    pub fn as_str(&self) -> &str {
93        &self.0
94    }
95}
96
/// Removes SQL comments from `raw`, leaving quoted text untouched.
///
/// * `-- …` line comments are dropped up to the newline that ends them; the
///   newline itself is kept. A line comment running to end of input is dropped
///   entirely.
/// * `/* … */` block comments (not nested) are replaced by a single space. This
///   keeps the tokens on either side separate: `select/**/1` becomes `select 1`
///   rather than `select1`. An unterminated block comment swallows the rest of
///   the input.
/// * Text between single quotes (string literals) or double quotes (quoted
///   identifiers) is copied verbatim. So `'--'` or `"a/*b"` are not mistaken for
///   comment openers.
/// * A doubled quote (`''`) closes and immediately reopens the literal, so it is
///   copied correctly.
/// * Backslash escapes (`E'\''`) and dollar quoting are not recognised. In the
///   first case the remainder may be copied verbatim rather than stripped. Any
///   downstream validation of the result then simply sees more text.
fn strip_comments(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars().peekable();
    // Quote character of the literal/identifier currently being copied, if any.
    let mut open_quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(q) = open_quote {
            out.push(c);
            if c == q {
                open_quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' => {
                open_quote = Some(c);
                out.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                for nc in chars.by_ref() {
                    if nc == '\n' {
                        out.push('\n');
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                let mut prev = '\0';
                for nc in chars.by_ref() {
                    if prev == '*' && nc == '/' {
                        break;
                    }
                    prev = nc;
                }
                // A comment separates tokens just like whitespace does.
                out.push(' ');
            }
            _ => out.push(c),
        }
    }
    out
}
125
/// Returns `true` if `haystack` begins with `needle` as a complete word.
///
/// The character after `needle` (if any) must not be able to continue an
/// identifier. That means it must be an ASCII character other than a letter,
/// digit, `_` or `$`, mirroring PostgreSQL's identifier rules, where non-ASCII
/// bytes also continue identifiers.
///
/// So `select*from t`, `select'a'` and `values(1)` start with the words
/// `select`/`values`. By contrast, `selection`, `select_x` and `select$1` do not.
fn starts_with_word(haystack: &str, needle: &str) -> bool {
    haystack.strip_prefix(needle).is_some_and(|rest| {
        rest.chars().next().is_none_or(|c| {
            c.is_ascii() && !(c.is_ascii_alphanumeric() || c == '_' || c == '$')
        })
    })
}