systemprompt_database/admin/
admin_sql.rs1use thiserror::Error;
2
/// Default row cap for read-only admin queries.
///
/// Not referenced elsewhere in this file. NOTE(review): presumably consumed by
/// the code that executes `AdminSql` read-only queries — confirm at call sites.
pub const DEFAULT_READONLY_ROW_LIMIT: usize = 1_000;
4
/// Lower-case words that a read-only statement may begin with.
///
/// `AdminSql::parse_readonly` lower-cases the SQL and checks it against these
/// entries with `starts_with_word`.
const READONLY_PREFIXES: &[&str] = &[
    "select",
    "with",
    "explain",
    "show",
    "table",
    "values",
];
6
/// Keywords whose presence makes `AdminSql::parse_readonly` reject a query.
///
/// Each entry is a single lower-case keyword padded with one space on each
/// side. The padding is kept for compatibility with substring-based matching;
/// a word-based matcher can simply `trim()` each entry.
const FORBIDDEN_KEYWORDS: &[&str] = &[
    " drop ",
    " delete ",
    " insert ",
    " update ",
    " alter ",
    " create ",
    " truncate ",
    " grant ",
    " revoke ",
    " copy ",
    " vacuum ",
    " call ",
    " lock ",
    " set ",
    " reset ",
    " rename ",
    // `SELECT ... INTO t` creates and fills a new table in PostgreSQL, even
    // though the statement starts with the read-only word `select`.
    " into ",
    // `WITH ... MERGE INTO ...` is a write that begins with the read-only
    // word `with`.
    " merge ",
];
25
26#[derive(Debug, Clone, Copy, Error)]
27pub enum AdminSqlError {
28 #[error("SQL query is empty")]
29 Empty,
30 #[error("SQL query contains multiple statements; only one is allowed")]
31 MultipleStatements,
32 #[error("SQL query must begin with SELECT, WITH, EXPLAIN, SHOW, TABLE, or VALUES")]
33 NotReadOnly,
34 #[error("SQL query contains forbidden keyword for read-only mode")]
35 ForbiddenKeyword,
36}
37
/// A single SQL statement accepted by `AdminSql::parse_readonly` or
/// `AdminSql::parse_unrestricted`.
///
/// The wrapped string is private, so code outside this module cannot build an
/// `AdminSql` directly.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdminSql(String);
40
41impl AdminSql {
42 pub fn parse_readonly(raw: &str) -> Result<Self, AdminSqlError> {
43 let stripped = strip_comments(raw);
44 let trimmed = stripped.trim();
45 if trimmed.is_empty() {
46 return Err(AdminSqlError::Empty);
47 }
48
49 let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
50 if without_terminator.contains(';') {
51 return Err(AdminSqlError::MultipleStatements);
52 }
53
54 let lower = without_terminator.to_lowercase();
55 if !READONLY_PREFIXES
56 .iter()
57 .any(|p| starts_with_word(&lower, p))
58 {
59 return Err(AdminSqlError::NotReadOnly);
60 }
61
62 let padded = format!(" {lower} ");
63 if FORBIDDEN_KEYWORDS.iter().any(|kw| padded.contains(kw)) {
64 return Err(AdminSqlError::ForbiddenKeyword);
65 }
66
67 Ok(Self(without_terminator.to_string()))
68 }
69
70 pub fn parse_unrestricted(raw: &str) -> Result<Self, AdminSqlError> {
71 let stripped = strip_comments(raw);
72 let trimmed = stripped.trim();
73 if trimmed.is_empty() {
74 return Err(AdminSqlError::Empty);
75 }
76
77 let without_terminator = trimmed.strip_suffix(';').unwrap_or(trimmed).trim_end();
78 if without_terminator.contains(';') {
79 return Err(AdminSqlError::MultipleStatements);
80 }
81
82 Ok(Self(without_terminator.to_string()))
83 }
84
85 pub fn as_str(&self) -> &str {
86 &self.0
87 }
88}
89
/// Removes SQL comments from `raw`, leaving quoted text untouched.
///
/// * `-- ...` line comments are dropped up to the end of the line. The
///   terminating newline itself is kept.
/// * `/* ... */` block comments are replaced by a single space, because SQL
///   treats a comment as whitespace: `a/**/b` is two tokens, not `ab`.
///   Nesting is not tracked. An unterminated block comment swallows the rest
///   of the input.
/// * Text inside single-quoted literals (`'...'`) and double-quoted
///   identifiers (`"..."`) is copied verbatim, so `'--'` or `"/*"` is not
///   mistaken for a comment. A doubled quote (`'it''s'`) works naturally: it
///   closes the quoted run and immediately reopens it.
///
/// Backslash escapes and dollar quoting are not recognised. If a quote is
/// misread, text is either stripped or kept verbatim. The result is still
/// what callers validate and execute.
fn strip_comments(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars().peekable();
    // The quote character we are currently inside, if any.
    let mut quote: Option<char> = None;

    while let Some(c) = chars.next() {
        if let Some(open) = quote {
            out.push(c);
            if c == open {
                quote = None;
            }
        } else if c == '\'' || c == '"' {
            quote = Some(c);
            out.push(c);
        } else if c == '-' && chars.peek() == Some(&'-') {
            // Skip to the end of the line, keeping the newline itself.
            for nc in chars.by_ref() {
                if nc == '\n' {
                    out.push('\n');
                    break;
                }
            }
        } else if c == '/' && chars.peek() == Some(&'*') {
            chars.next(); // consume the '*' of the opener
            let mut prev = '\0';
            for nc in chars.by_ref() {
                if prev == '*' && nc == '/' {
                    break;
                }
                prev = nc;
            }
            // A comment separates tokens, so it must not glue them together.
            out.push(' ');
        } else {
            out.push(c);
        }
    }
    out
}
118
/// Returns `true` if `haystack` starts with `needle` as a complete word.
///
/// After the prefix, the input must either end or continue with a character
/// that cannot extend an unquoted SQL identifier. Identifier characters are
/// letters, digits, `_` and `$`. So `select 1`, `select*from t`,
/// `select"col"`, `values(1)` and a bare `select` all start with the word
/// `select`, whereas `selection` and `select_x` do not.
///
/// The comparison is case-sensitive.
fn starts_with_word(haystack: &str, needle: &str) -> bool {
    haystack.strip_prefix(needle).is_some_and(|rest| {
        rest.chars()
            .next()
            .is_none_or(|next| !(next.is_alphanumeric() || next == '_' || next == '$'))
    })
}