1use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{Value, json};
17use std::time::Duration;
18
/// Number of rows returned when the caller does not pass `limit`.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Upper bound for a caller-supplied `limit`; larger values are clamped down to this.
const MAX_ROW_LIMIT: i32 = 1_000;

/// Seconds to wait on a locked database before giving up (passed to `busy_timeout`).
///
/// NOTE(review): despite the name, this is a lock-wait (busy) timeout, not a limit on how
/// long a query is allowed to execute.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Output format for the rows returned by the `query` tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// Object with `columns`, per-row objects and metadata.
    Json,
    /// Comma-separated values with a header line.
    Csv,
    /// Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Parses a format name case-insensitively; `md` is accepted as an alias for Markdown.
    ///
    /// Returns `None` for any unrecognised name.
    fn from_str(s: &str) -> Option<Self> {
        const NAMES: [(&str, QueryFormat); 4] = [
            ("json", QueryFormat::Json),
            ("csv", QueryFormat::Csv),
            ("markdown", QueryFormat::Markdown),
            ("md", QueryFormat::Markdown),
        ];

        let wanted = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == wanted)
            .map(|&(_, format)| format)
    }
}
46
47pub fn get_tools() -> Vec<Tool> {
49 let mut tool = make_tool(
50 "query",
51 "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52 Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53 advanced reporting. Returns columns, rows, and metadata.",
54 json!({
55 "sql": {
56 "type": "string",
57 "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58 },
59 "params": {
60 "type": "array",
61 "items": { "type": "string" },
62 "description": "Bind parameters for the query (use ? placeholders in SQL)"
63 },
64 "limit": {
65 "type": "integer",
66 "description": "Maximum number of rows to return (default: 100, max: 1000)"
67 },
68 "format": {
69 "type": "string",
70 "enum": ["json", "csv", "markdown"],
71 "description": "Output format for results (default: json)"
72 }
73 }),
74 vec!["sql"],
75 );
76
77 tool.annotations = Some(ToolAnnotations {
81 title: Some("SQL Query".into()),
82 read_only_hint: Some(true),
83 destructive_hint: Some(false),
84 idempotent_hint: Some(true),
85 open_world_hint: Some(false),
86 });
87
88 vec![tool]
89}
90
91#[allow(clippy::result_large_err)]
93fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
94 let normalized = sql.trim().to_uppercase();
96
97 let forbidden_prefixes = [
99 "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE", "REPLACE", "UPSERT",
100 "MERGE", "GRANT", "REVOKE", "ATTACH", "DETACH", "VACUUM", "REINDEX", "ANALYZE",
101 "PRAGMA", ];
103
104 let first_word = normalized.split_whitespace().next().unwrap_or("");
106
107 if first_word != "SELECT" && first_word != "WITH" {
108 return Err(ToolError::new(
109 ErrorCode::InvalidFieldValue,
110 format!(
111 "Only SELECT queries are allowed. Got: {}",
112 if first_word.len() > 20 {
113 &first_word[..20]
114 } else {
115 first_word
116 }
117 ),
118 )
119 .with_field("sql"));
120 }
121
122 if first_word == "WITH" {
124 for forbidden in &forbidden_prefixes {
126 let pattern = format!(r"\b{}\b", forbidden);
128 if let Ok(re) = regex_lite::Regex::new(&pattern)
129 && re.is_match(&normalized)
130 {
131 return Err(ToolError::new(
132 ErrorCode::InvalidFieldValue,
133 format!("{} statements are not allowed in queries", forbidden),
134 )
135 .with_field("sql"));
136 }
137 }
138 }
139
140 let semicolon_count = sql.matches(';').count();
143 if semicolon_count > 1 {
144 return Err(ToolError::new(
145 ErrorCode::InvalidFieldValue,
146 "Multiple SQL statements are not allowed",
147 )
148 .with_field("sql"));
149 }
150
151 for forbidden in &forbidden_prefixes {
153 let pattern = format!(r"\b{}\s+", forbidden);
155 if let Ok(re) = regex_lite::Regex::new(&pattern)
156 && re.is_match(&normalized)
157 {
158 return Err(ToolError::new(
159 ErrorCode::InvalidFieldValue,
160 format!("{} statements are not allowed", forbidden),
161 )
162 .with_field("sql"));
163 }
164 }
165
166 Ok(())
167}
168
169pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
171 let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
172
173 let params = get_string_array(&args, "params").unwrap_or_default();
174
175 let limit = get_i32(&args, "limit")
176 .map(|l| l.clamp(1, MAX_ROW_LIMIT))
177 .unwrap_or(DEFAULT_ROW_LIMIT);
178
179 let format = get_string(&args, "format")
181 .and_then(|f| QueryFormat::from_str(&f))
182 .unwrap_or(match default_format {
183 OutputFormat::Json => QueryFormat::Json,
184 OutputFormat::Markdown => QueryFormat::Markdown,
185 });
186
187 validate_readonly_sql(&sql)?;
189
190 let result = db.with_conn(|conn| {
192 conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
194
195 let mut stmt = conn.prepare(&sql)?;
197
198 let column_count = stmt.column_count();
200 let columns: Vec<String> = (0..column_count)
201 .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
202 .collect();
203
204 let params_refs: Vec<&dyn rusqlite::ToSql> =
206 params.iter().map(|s| s as &dyn rusqlite::ToSql).collect();
207
208 let mut rows_data: Vec<Vec<Value>> = Vec::new();
210 let mut row_iter = stmt.query(params_refs.as_slice())?;
211
212 let mut count = 0;
213 while let Some(row) = row_iter.next()? {
214 if count >= limit {
215 break;
216 }
217
218 let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
219 for i in 0..column_count {
220 let value: Value = match row.get_ref(i)? {
221 rusqlite::types::ValueRef::Null => Value::Null,
222 rusqlite::types::ValueRef::Integer(i) => json!(i),
223 rusqlite::types::ValueRef::Real(f) => json!(f),
224 rusqlite::types::ValueRef::Text(s) => {
225 json!(String::from_utf8_lossy(s).to_string())
226 }
227 rusqlite::types::ValueRef::Blob(b) => {
228 json!(base64::Engine::encode(
229 &base64::engine::general_purpose::STANDARD,
230 b
231 ))
232 }
233 };
234 row_values.push(value);
235 }
236 rows_data.push(row_values);
237 count += 1;
238 }
239
240 let has_more = row_iter.next()?.is_some();
242
243 Ok((columns, rows_data, count, has_more))
244 })?;
245
246 let (columns, rows_data, row_count, truncated) = result;
247
248 match format {
250 QueryFormat::Json => {
251 let rows: Vec<Value> = rows_data
253 .iter()
254 .map(|row| {
255 let obj: serde_json::Map<String, Value> = columns
256 .iter()
257 .zip(row.iter())
258 .map(|(col, val)| (col.clone(), val.clone()))
259 .collect();
260 Value::Object(obj)
261 })
262 .collect();
263
264 Ok(ToolResult::Json(json!({
265 "columns": columns,
266 "rows": rows,
267 "row_count": row_count,
268 "truncated": truncated,
269 "limit": limit
270 })))
271 }
272 QueryFormat::Csv => {
273 let mut csv = String::new();
274 csv.push_str(&columns.join(","));
276 csv.push('\n');
277 for row in &rows_data {
279 let values: Vec<String> = row
280 .iter()
281 .map(|v| match v {
282 Value::Null => String::new(),
283 Value::String(s) => {
284 if s.contains(',') || s.contains('"') || s.contains('\n') {
286 format!("\"{}\"", s.replace('"', "\"\""))
287 } else {
288 s.clone()
289 }
290 }
291 _ => v.to_string(),
292 })
293 .collect();
294 csv.push_str(&values.join(","));
295 csv.push('\n');
296 }
297
298 if truncated {
300 csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
301 }
302 Ok(ToolResult::Raw(csv))
303 }
304 QueryFormat::Markdown => {
305 let mut md = String::new();
306
307 if columns.is_empty() {
308 md.push_str("*No columns*\n");
309 } else {
310 md.push_str("| ");
312 md.push_str(&columns.join(" | "));
313 md.push_str(" |\n");
314
315 md.push_str("| ");
317 md.push_str(
318 &columns
319 .iter()
320 .map(|_| "---")
321 .collect::<Vec<_>>()
322 .join(" | "),
323 );
324 md.push_str(" |\n");
325
326 for row in &rows_data {
328 md.push_str("| ");
329 let values: Vec<String> = row
330 .iter()
331 .map(|v| match v {
332 Value::Null => String::from("*null*"),
333 Value::String(s) => s.replace('|', "\\|"),
334 _ => v.to_string(),
335 })
336 .collect();
337 md.push_str(&values.join(" | "));
338 md.push_str(" |\n");
339 }
340 }
341
342 if truncated {
343 md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
344 }
345
346 Ok(ToolResult::Raw(md))
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the validator on `sql`, failing the test if the query is accepted.
    fn assert_rejected(sql: &str) -> ToolError {
        validate_readonly_sql(sql).expect_err("query should have been rejected")
    }

    #[test]
    fn test_validate_readonly_select() {
        for sql in [
            "SELECT * FROM tasks",
            " SELECT id FROM tasks WHERE status = 'pending' ",
            "select count(*) from tasks",
        ] {
            assert!(validate_readonly_sql(sql).is_ok(), "unexpectedly rejected: {sql}");
        }
    }

    #[test]
    fn test_validate_readonly_with_cte() {
        let cte = "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts";
        assert!(validate_readonly_sql(cte).is_ok());
    }

    #[test]
    fn test_validate_readonly_rejects_insert() {
        let err = assert_rejected("INSERT INTO tasks (title) VALUES ('test')");
        assert!(err.message.contains("INSERT"));
    }

    #[test]
    fn test_validate_readonly_rejects_update() {
        assert_rejected("UPDATE tasks SET status = 'done'");
    }

    #[test]
    fn test_validate_readonly_rejects_delete() {
        assert_rejected("DELETE FROM tasks WHERE id = 'xxx'");
    }

    #[test]
    fn test_validate_readonly_rejects_drop() {
        assert_rejected("DROP TABLE tasks");
    }

    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let err = assert_rejected("SELECT 1; DROP TABLE tasks;");
        assert!(err.message.contains("Multiple"));
    }

    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        for sql in [
            "SELECT deleted_at FROM tasks",
            "SELECT updated_at, created_at FROM tasks",
        ] {
            assert!(validate_readonly_sql(sql).is_ok(), "unexpectedly rejected: {sql}");
        }
    }

    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QueryFormat::from_str(input), expected, "input: {input}");
        }
    }
}