1use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{Value, json};
17use std::time::Duration;
18
/// Number of rows returned by `query` when the caller does not pass `limit`.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Upper bound that a caller-supplied `limit` is clamped to.
const MAX_ROW_LIMIT: i32 = 1_000;

/// Seconds handed to SQLite's busy timeout before a query runs. This bounds how long SQLite
/// waits on a locked database; it is not a limit on how long the query itself executes.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Output encodings the `query` tool can render its result rows in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// A JSON document with column names, row objects, and metadata.
    Json,
    /// Comma-separated values with a header line.
    Csv,
    /// A Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Parses a format name, ignoring case.
    ///
    /// Recognises `json`, `csv`, `markdown`, and the alias `md`; anything else yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        let name = s.to_lowercase();
        if name == "json" {
            Some(Self::Json)
        } else if name == "csv" {
            Some(Self::Csv)
        } else if matches!(name.as_str(), "markdown" | "md") {
            Some(Self::Markdown)
        } else {
            None
        }
    }
}
46
47pub fn get_tools() -> Vec<Tool> {
49 let mut tool = make_tool(
50 "query",
51 "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52 Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53 advanced reporting. Returns columns, rows, and metadata.",
54 json!({
55 "sql": {
56 "type": "string",
57 "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58 },
59 "params": {
60 "type": "array",
61 "items": { "type": "string" },
62 "description": "Bind parameters for the query (use ? placeholders in SQL)"
63 },
64 "limit": {
65 "type": "integer",
66 "description": "Maximum number of rows to return (default: 100, max: 1000)"
67 },
68 "format": {
69 "type": "string",
70 "enum": ["json", "csv", "markdown"],
71 "description": "Output format for results (default: json)"
72 }
73 }),
74 vec!["sql"],
75 );
76
77 tool.annotations = Some(ToolAnnotations {
81 title: Some("SQL Query".into()),
82 read_only_hint: Some(true),
83 destructive_hint: Some(false),
84 idempotent_hint: Some(true),
85 open_world_hint: Some(false),
86 });
87
88 vec![tool]
89}
90
91fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
93 let normalized = sql.trim().to_uppercase();
95
96 let forbidden_prefixes = [
98 "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE", "REPLACE", "UPSERT",
99 "MERGE", "GRANT", "REVOKE", "ATTACH", "DETACH", "VACUUM", "REINDEX", "ANALYZE",
100 "PRAGMA", ];
102
103 let first_word = normalized.split_whitespace().next().unwrap_or("");
105
106 if first_word != "SELECT" && first_word != "WITH" {
107 return Err(ToolError::new(
108 ErrorCode::InvalidFieldValue,
109 format!(
110 "Only SELECT queries are allowed. Got: {}",
111 if first_word.len() > 20 {
112 &first_word[..20]
113 } else {
114 first_word
115 }
116 ),
117 )
118 .with_field("sql"));
119 }
120
121 if first_word == "WITH" {
123 for forbidden in &forbidden_prefixes {
125 let pattern = format!(r"\b{}\b", forbidden);
127 if let Ok(re) = regex_lite::Regex::new(&pattern)
128 && re.is_match(&normalized)
129 {
130 return Err(ToolError::new(
131 ErrorCode::InvalidFieldValue,
132 format!("{} statements are not allowed in queries", forbidden),
133 )
134 .with_field("sql"));
135 }
136 }
137 }
138
139 let semicolon_count = sql.matches(';').count();
142 if semicolon_count > 1 {
143 return Err(ToolError::new(
144 ErrorCode::InvalidFieldValue,
145 "Multiple SQL statements are not allowed",
146 )
147 .with_field("sql"));
148 }
149
150 for forbidden in &forbidden_prefixes {
152 let pattern = format!(r"\b{}\s+", forbidden);
154 if let Ok(re) = regex_lite::Regex::new(&pattern)
155 && re.is_match(&normalized)
156 {
157 return Err(ToolError::new(
158 ErrorCode::InvalidFieldValue,
159 format!("{} statements are not allowed", forbidden),
160 )
161 .with_field("sql"));
162 }
163 }
164
165 Ok(())
166}
167
168pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
170 let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
171
172 let params = get_string_array(&args, "params").unwrap_or_default();
173
174 let limit = get_i32(&args, "limit")
175 .map(|l| l.clamp(1, MAX_ROW_LIMIT))
176 .unwrap_or(DEFAULT_ROW_LIMIT);
177
178 let format = get_string(&args, "format")
180 .and_then(|f| QueryFormat::from_str(&f))
181 .unwrap_or(match default_format {
182 OutputFormat::Json => QueryFormat::Json,
183 OutputFormat::Markdown => QueryFormat::Markdown,
184 });
185
186 validate_readonly_sql(&sql)?;
188
189 let result = db.with_conn(|conn| {
191 conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
193
194 let mut stmt = conn.prepare(&sql)?;
196
197 let column_count = stmt.column_count();
199 let columns: Vec<String> = (0..column_count)
200 .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
201 .collect();
202
203 let params_refs: Vec<&dyn rusqlite::ToSql> =
205 params.iter().map(|s| s as &dyn rusqlite::ToSql).collect();
206
207 let mut rows_data: Vec<Vec<Value>> = Vec::new();
209 let mut row_iter = stmt.query(params_refs.as_slice())?;
210
211 let mut count = 0;
212 while let Some(row) = row_iter.next()? {
213 if count >= limit {
214 break;
215 }
216
217 let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
218 for i in 0..column_count {
219 let value: Value = match row.get_ref(i)? {
220 rusqlite::types::ValueRef::Null => Value::Null,
221 rusqlite::types::ValueRef::Integer(i) => json!(i),
222 rusqlite::types::ValueRef::Real(f) => json!(f),
223 rusqlite::types::ValueRef::Text(s) => {
224 json!(String::from_utf8_lossy(s).to_string())
225 }
226 rusqlite::types::ValueRef::Blob(b) => {
227 json!(base64::Engine::encode(
228 &base64::engine::general_purpose::STANDARD,
229 b
230 ))
231 }
232 };
233 row_values.push(value);
234 }
235 rows_data.push(row_values);
236 count += 1;
237 }
238
239 let has_more = row_iter.next()?.is_some();
241
242 Ok((columns, rows_data, count, has_more))
243 })?;
244
245 let (columns, rows_data, row_count, truncated) = result;
246
247 match format {
249 QueryFormat::Json => {
250 let rows: Vec<Value> = rows_data
252 .iter()
253 .map(|row| {
254 let obj: serde_json::Map<String, Value> = columns
255 .iter()
256 .zip(row.iter())
257 .map(|(col, val)| (col.clone(), val.clone()))
258 .collect();
259 Value::Object(obj)
260 })
261 .collect();
262
263 Ok(ToolResult::Json(json!({
264 "columns": columns,
265 "rows": rows,
266 "row_count": row_count,
267 "truncated": truncated,
268 "limit": limit
269 })))
270 }
271 QueryFormat::Csv => {
272 let mut csv = String::new();
273 csv.push_str(&columns.join(","));
275 csv.push('\n');
276 for row in &rows_data {
278 let values: Vec<String> = row
279 .iter()
280 .map(|v| match v {
281 Value::Null => String::new(),
282 Value::String(s) => {
283 if s.contains(',') || s.contains('"') || s.contains('\n') {
285 format!("\"{}\"", s.replace('"', "\"\""))
286 } else {
287 s.clone()
288 }
289 }
290 _ => v.to_string(),
291 })
292 .collect();
293 csv.push_str(&values.join(","));
294 csv.push('\n');
295 }
296
297 if truncated {
299 csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
300 }
301 Ok(ToolResult::Raw(csv))
302 }
303 QueryFormat::Markdown => {
304 let mut md = String::new();
305
306 if columns.is_empty() {
307 md.push_str("*No columns*\n");
308 } else {
309 md.push_str("| ");
311 md.push_str(&columns.join(" | "));
312 md.push_str(" |\n");
313
314 md.push_str("| ");
316 md.push_str(
317 &columns
318 .iter()
319 .map(|_| "---")
320 .collect::<Vec<_>>()
321 .join(" | "),
322 );
323 md.push_str(" |\n");
324
325 for row in &rows_data {
327 md.push_str("| ");
328 let values: Vec<String> = row
329 .iter()
330 .map(|v| match v {
331 Value::Null => String::from("*null*"),
332 Value::String(s) => s.replace('|', "\\|"),
333 _ => v.to_string(),
334 })
335 .collect();
336 md.push_str(&values.join(" | "));
337 md.push_str(" |\n");
338 }
339 }
340
341 if truncated {
342 md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
343 }
344
345 Ok(ToolResult::Raw(md))
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `sql` passes `validate_readonly_sql`.
    fn assert_accepted(sql: &str) {
        assert!(
            validate_readonly_sql(sql).is_ok(),
            "expected SQL to be accepted: {sql}"
        );
    }

    /// Fails the test unless `sql` is rejected; returns the rejection message.
    fn assert_rejected(sql: &str) -> String {
        match validate_readonly_sql(sql) {
            Ok(()) => panic!("expected SQL to be rejected: {sql}"),
            Err(err) => err.message.to_string(),
        }
    }

    #[test]
    fn test_validate_readonly_select() {
        [
            "SELECT * FROM tasks",
            " SELECT id FROM tasks WHERE status = 'pending' ",
            "select count(*) from tasks",
        ]
        .into_iter()
        .for_each(assert_accepted);
    }

    #[test]
    fn test_validate_readonly_with_cte() {
        assert_accepted(
            "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts",
        );
    }

    #[test]
    fn test_validate_readonly_rejects_insert() {
        let message = assert_rejected("INSERT INTO tasks (title) VALUES ('test')");
        assert!(message.contains("INSERT"));
    }

    #[test]
    fn test_validate_readonly_rejects_update() {
        assert_rejected("UPDATE tasks SET status = 'done'");
    }

    #[test]
    fn test_validate_readonly_rejects_delete() {
        assert_rejected("DELETE FROM tasks WHERE id = 'xxx'");
    }

    #[test]
    fn test_validate_readonly_rejects_drop() {
        assert_rejected("DROP TABLE tasks");
    }

    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let message = assert_rejected("SELECT 1; DROP TABLE tasks;");
        assert!(message.contains("Multiple"));
    }

    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        for sql in [
            "SELECT deleted_at FROM tasks",
            "SELECT updated_at, created_at FROM tasks",
        ] {
            assert_accepted(sql);
        }
    }

    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QueryFormat::from_str(input), expected, "input: {input}");
        }
    }
}