systemprompt_database/admin/
admin_sql.rs1use thiserror::Error;
9
/// Default row limit for read-only admin queries.
///
/// NOTE(review): not referenced in this part of the file; presumably applied
/// by whatever executes the result of `AdminSql::parse_readonly` — confirm.
pub const DEFAULT_READONLY_ROW_LIMIT: usize = 1_000;

/// Leading keywords accepted for a read-only statement. Each one is matched as
/// a whole word against the lower-cased, comment-stripped statement.
const READONLY_PREFIXES: &[&str] = &["select", "with", "explain", "show", "table", "values"];
13
/// Keywords that disqualify a statement from read-only mode.
///
/// Every entry carries one leading and one trailing space, so it only matches
/// a whole word when searched for inside a space-padded statement; a matcher
/// may equally `trim()` an entry and compare it against individual words.
///
/// Beyond the obvious DML/DDL/session keywords:
/// - `into` catches `SELECT … INTO new_table`, which creates a table in
///   PostgreSQL even though the statement starts with `SELECT`.
/// - `merge` catches `WITH … MERGE INTO …`, which would otherwise pass the
///   `with` prefix check.
const FORBIDDEN_KEYWORDS: &[&str] = &[
    " drop ",
    " delete ",
    " insert ",
    " into ",
    " merge ",
    " update ",
    " alter ",
    " create ",
    " truncate ",
    " grant ",
    " revoke ",
    " copy ",
    " vacuum ",
    " call ",
    " lock ",
    " set ",
    " reset ",
    " rename ",
];
32
33#[derive(Debug, Clone, Copy, Error)]
34pub enum AdminSqlError {
35 #[error("SQL query is empty")]
36 Empty,
37 #[error("SQL query contains multiple statements; only one is allowed")]
38 MultipleStatements,
39 #[error("SQL query must begin with SELECT, WITH, EXPLAIN, SHOW, TABLE, or VALUES")]
40 NotReadOnly,
41 #[error("SQL query contains forbidden keyword for read-only mode")]
42 ForbiddenKeyword,
43}
44
/// A single SQL statement that has passed admin validation.
///
/// The inner `String` is private: the parse functions in this module build
/// it, and `as_str` exposes it read-only.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdminSql(String);
47
48impl AdminSql {
49 pub fn parse_readonly(raw: &str) -> Result<Self, AdminSqlError> {
50 let stripped = strip_comments(raw);
51 let trimmed = stripped.trim();
52 if trimmed.is_empty() {
53 return Err(AdminSqlError::Empty);
54 }
55
56 let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
57 if without_terminator.contains(';') {
58 return Err(AdminSqlError::MultipleStatements);
59 }
60
61 let lower = without_terminator.to_lowercase();
62 if !READONLY_PREFIXES
63 .iter()
64 .any(|p| starts_with_word(&lower, p))
65 {
66 return Err(AdminSqlError::NotReadOnly);
67 }
68
69 let padded = format!(" {lower} ");
70 if FORBIDDEN_KEYWORDS.iter().any(|kw| padded.contains(kw)) {
71 return Err(AdminSqlError::ForbiddenKeyword);
72 }
73
74 Ok(Self(without_terminator.to_string()))
75 }
76
77 pub fn parse_unrestricted(raw: &str) -> Result<Self, AdminSqlError> {
78 let stripped = strip_comments(raw);
79 let trimmed = stripped.trim();
80 if trimmed.is_empty() {
81 return Err(AdminSqlError::Empty);
82 }
83
84 let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
85 if without_terminator.contains(';') {
86 return Err(AdminSqlError::MultipleStatements);
87 }
88
89 Ok(Self(without_terminator.to_string()))
90 }
91
92 pub fn as_str(&self) -> &str {
93 &self.0
94 }
95}
96
/// Removes SQL comments from `raw` while leaving quoted text untouched.
///
/// - `-- …` line comments are dropped up to the end of the line; the newline
///   itself is kept.
/// - `/* … */` block comments are replaced by a single space, because SQL
///   treats a comment as whitespace. Deleting it outright would glue the
///   neighbouring tokens together (`sel/**/ect` must not become `select`).
///   Block comments may nest, as in PostgreSQL. An unterminated block comment
///   swallows the rest of the input.
/// - Text inside single-quoted literals (`'…'`) and double-quoted identifiers
///   (`"…"`) is copied verbatim, so `'--'` or `'/*'` inside a string no longer
///   truncates the query. A doubled quote (`''`) closes and immediately
///   reopens the quoted run, which keeps the content intact.
///
/// Backslash escapes in `E'…'` strings and dollar-quoted strings are not
/// recognised, so comment markers inside those may still be stripped. The
/// returned text is what gets validated, so such mis-stripping can only break
/// a query; it cannot hide SQL from validation.
fn strip_comments(raw: &str) -> String {
    /// Consumes a block comment whose opening `/*` has already been read,
    /// honouring nested `/* … */` pairs; stops right after the matching `*/`
    /// or at end of input.
    fn skip_block_comment(chars: &mut impl Iterator<Item = char>) {
        let mut depth = 1_usize;
        // Previous unpaired character. Reset after a `/*` or `*/` is consumed
        // so that e.g. `*/*` is not read as both a closer and an opener.
        let mut prev = '\0';
        for c in chars {
            match (prev, c) {
                ('*', '/') => {
                    depth -= 1;
                    if depth == 0 {
                        return;
                    }
                    prev = '\0';
                }
                ('/', '*') => {
                    depth += 1;
                    prev = '\0';
                }
                _ => prev = c,
            }
        }
    }

    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars().peekable();
    // Quote character of the literal/identifier currently open, if any.
    let mut open_quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(quote) = open_quote {
            out.push(c);
            if c == quote {
                open_quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' => {
                open_quote = Some(c);
                out.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Skip to the end of the line, keeping the newline (if any).
                if chars.any(|nc| nc == '\n') {
                    out.push('\n');
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                skip_block_comment(&mut chars);
                out.push(' ');
            }
            _ => out.push(c),
        }
    }
    out
}
125
/// Returns `true` when `haystack` begins with `needle` as a whole word.
///
/// `needle` must be followed by end of input or by a character that cannot
/// continue an unquoted SQL identifier: anything other than an alphanumeric
/// character, `_` or `$`. This accepts forms like `select*from t`,
/// `values(1)` and `select'x'`, and still rejects `selection`, `select_x` and
/// `select$1`.
fn starts_with_word(haystack: &str, needle: &str) -> bool {
    haystack.strip_prefix(needle).is_some_and(|rest| {
        rest.chars()
            .next()
            .is_none_or(|c| !(c.is_alphanumeric() || c == '_' || c == '$'))
    })
}