Skip to main content

systemprompt_database/admin/
admin_sql.rs

1use thiserror::Error;
2
/// Default maximum number of rows for read-only admin queries.
///
/// NOTE(review): not referenced in this file; presumably callers use it to cap
/// result sets of `AdminSql::parse_readonly` queries — confirm at the call site.
pub const DEFAULT_READONLY_ROW_LIMIT: usize = 1_000;
4
/// Lower-case leading keywords a statement must start with (as a whole word)
/// to be accepted in read-only mode.
const READONLY_PREFIXES: &[&str] = &[
    "select",
    "with",
    "explain",
    "show",
    "table",
    "values",
];
6
/// Keywords that cause a query to be rejected in read-only mode.
///
/// Each entry is written with one leading and one trailing space; the
/// lower-cased query is checked for whole-word occurrences of the keyword.
///
/// `into` rejects `SELECT ... INTO`, which creates a table, and `MERGE INTO`.
/// `execute` rejects `EXPLAIN ANALYZE EXECUTE`, which on PostgreSQL runs an
/// already prepared statement (that statement may be a write).
const FORBIDDEN_KEYWORDS: &[&str] = &[
    " drop ",
    " delete ",
    " insert ",
    " update ",
    " alter ",
    " create ",
    " truncate ",
    " grant ",
    " revoke ",
    " copy ",
    " vacuum ",
    " call ",
    " lock ",
    " set ",
    " reset ",
    " rename ",
    " into ",
    " execute ",
];
25
26#[derive(Debug, Clone, Copy, Error)]
27pub enum AdminSqlError {
28    #[error("SQL query is empty")]
29    Empty,
30    #[error("SQL query contains multiple statements; only one is allowed")]
31    MultipleStatements,
32    #[error("SQL query must begin with SELECT, WITH, EXPLAIN, SHOW, TABLE, or VALUES")]
33    NotReadOnly,
34    #[error("SQL query contains forbidden keyword for read-only mode")]
35    ForbiddenKeyword,
36}
37
/// A single SQL statement that has passed validation.
///
/// Values built by the `AdminSql::parse_*` constructors hold the text with
/// comments stripped and one trailing `;` removed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdminSql(String);
40
41impl AdminSql {
42    pub fn parse_readonly(raw: &str) -> Result<Self, AdminSqlError> {
43        let stripped = strip_comments(raw);
44        let trimmed = stripped.trim();
45        if trimmed.is_empty() {
46            return Err(AdminSqlError::Empty);
47        }
48
49        let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
50        if without_terminator.contains(';') {
51            return Err(AdminSqlError::MultipleStatements);
52        }
53
54        let lower = without_terminator.to_lowercase();
55        if !READONLY_PREFIXES
56            .iter()
57            .any(|p| starts_with_word(&lower, p))
58        {
59            return Err(AdminSqlError::NotReadOnly);
60        }
61
62        let padded = format!(" {lower} ");
63        if FORBIDDEN_KEYWORDS.iter().any(|kw| padded.contains(kw)) {
64            return Err(AdminSqlError::ForbiddenKeyword);
65        }
66
67        Ok(Self(without_terminator.to_string()))
68    }
69
70    pub fn parse_unrestricted(raw: &str) -> Result<Self, AdminSqlError> {
71        let stripped = strip_comments(raw);
72        let trimmed = stripped.trim();
73        if trimmed.is_empty() {
74            return Err(AdminSqlError::Empty);
75        }
76
77        let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
78        if without_terminator.contains(';') {
79            return Err(AdminSqlError::MultipleStatements);
80        }
81
82        Ok(Self(without_terminator.to_string()))
83    }
84
85    pub fn as_str(&self) -> &str {
86        &self.0
87    }
88}
89
/// Removes `--` line comments and `/* ... */` block comments from `raw`.
///
/// Quoting:
///
/// - Text inside single-quoted literals and double-quoted identifiers is
///   copied unchanged, so `'--'` or `'/*'` in a literal is not treated as a
///   comment.
/// - Quote tracking understands doubled quotes (`'it''s'`). It does not
///   understand backslash escapes in `E'...'` strings or dollar-quoted strings.
///
/// Comments:
///
/// - A line comment keeps its terminating newline.
/// - A block comment is replaced by a single space, so the tokens on either
///   side stay separate (`a/**/b` becomes `a b`, not `ab`).
/// - Block comments do not nest; an unterminated one runs to the end of input.
fn strip_comments(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars().peekable();
    // The quote character of the literal/identifier currently open, if any.
    let mut quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(q) = quote {
            out.push(c);
            // A doubled quote closes and immediately reopens, which is equivalent.
            if c == q {
                quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' => {
                quote = Some(c);
                out.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                for nc in chars.by_ref() {
                    if nc == '\n' {
                        out.push('\n');
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                let mut prev = '\0';
                for nc in chars.by_ref() {
                    if prev == '*' && nc == '/' {
                        break;
                    }
                    prev = nc;
                }
                // SQL treats a comment as whitespace; keep tokens apart.
                out.push(' ');
            }
            _ => out.push(c),
        }
    }
    out
}
118
/// Returns `true` if `haystack` begins with `needle` as a complete word.
///
/// The character after `needle` must be absent or one that cannot continue an
/// unquoted SQL identifier (identifiers continue with letters, digits, `_`,
/// and `$`). Inputs such as `select*from t`, `values(1)` and `with"x" as ...`
/// are therefore accepted, while `selection` and `select_x` are not.
fn starts_with_word(haystack: &str, needle: &str) -> bool {
    haystack.strip_prefix(needle).is_some_and(|rest| {
        !matches!(
            rest.chars().next(),
            Some(c) if c.is_alphanumeric() || c == '_' || c == '$'
        )
    })
}