1use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{json, Value};
17use std::time::Duration;
18
/// Rows returned when the caller does not supply `limit`.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Largest `limit` a caller may request; larger values are clamped down to this.
const MAX_ROW_LIMIT: i32 = 1_000;

/// Seconds handed to SQLite's busy timeout, i.e. how long to wait on a locked database.
/// Despite the name, this does not bound how long a query may run once it holds the lock.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Encodings the `query` tool can render a result set in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// An object holding `columns`, `rows` (one object per row) and result metadata.
    Json,
    /// Comma-separated values preceded by a header line of column names.
    Csv,
    /// A Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Parses a user-supplied format name, ignoring case; `"md"` is accepted as Markdown.
    ///
    /// Returns `None` for unrecognized names so the caller can pick a fallback.
    fn from_str(s: &str) -> Option<Self> {
        // Accepted spellings (already lowercase) and the variant each one selects.
        const ALIASES: [(&str, QueryFormat); 4] = [
            ("json", QueryFormat::Json),
            ("csv", QueryFormat::Csv),
            ("markdown", QueryFormat::Markdown),
            ("md", QueryFormat::Markdown),
        ];

        // Full Unicode lowercasing, matching how names were normalized originally.
        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted)
            .map(|&(_, format)| format)
    }
}
46
47pub fn get_tools() -> Vec<Tool> {
49 let mut tool = make_tool(
50 "query",
51 "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52 Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53 advanced reporting. Returns columns, rows, and metadata.",
54 json!({
55 "sql": {
56 "type": "string",
57 "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58 },
59 "params": {
60 "type": "array",
61 "items": { "type": "string" },
62 "description": "Bind parameters for the query (use ? placeholders in SQL)"
63 },
64 "limit": {
65 "type": "integer",
66 "description": "Maximum number of rows to return (default: 100, max: 1000)"
67 },
68 "format": {
69 "type": "string",
70 "enum": ["json", "csv", "markdown"],
71 "description": "Output format for results (default: json)"
72 }
73 }),
74 vec!["sql"],
75 );
76
77 tool.annotations = Some(ToolAnnotations {
81 title: Some("SQL Query".into()),
82 read_only_hint: Some(true),
83 destructive_hint: Some(false),
84 idempotent_hint: Some(true),
85 open_world_hint: Some(false),
86 });
87
88 vec![tool]
89}
90
91fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
93 let normalized = sql.trim().to_uppercase();
95
96 let forbidden_prefixes = [
98 "INSERT",
99 "UPDATE",
100 "DELETE",
101 "DROP",
102 "CREATE",
103 "ALTER",
104 "TRUNCATE",
105 "REPLACE",
106 "UPSERT",
107 "MERGE",
108 "GRANT",
109 "REVOKE",
110 "ATTACH",
111 "DETACH",
112 "VACUUM",
113 "REINDEX",
114 "ANALYZE",
115 "PRAGMA", ];
117
118 let first_word = normalized.split_whitespace().next().unwrap_or("");
120
121 if first_word != "SELECT" && first_word != "WITH" {
122 return Err(ToolError::new(
123 ErrorCode::InvalidFieldValue,
124 format!(
125 "Only SELECT queries are allowed. Got: {}",
126 if first_word.len() > 20 {
127 &first_word[..20]
128 } else {
129 first_word
130 }
131 ),
132 )
133 .with_field("sql"));
134 }
135
136 if first_word == "WITH" {
138 for forbidden in &forbidden_prefixes {
140 let pattern = format!(r"\b{}\b", forbidden);
142 if let Ok(re) = regex_lite::Regex::new(&pattern)
143 && re.is_match(&normalized)
144 {
145 return Err(ToolError::new(
146 ErrorCode::InvalidFieldValue,
147 format!("{} statements are not allowed in queries", forbidden),
148 )
149 .with_field("sql"));
150 }
151 }
152 }
153
154 let semicolon_count = sql.matches(';').count();
157 if semicolon_count > 1 {
158 return Err(ToolError::new(
159 ErrorCode::InvalidFieldValue,
160 "Multiple SQL statements are not allowed",
161 )
162 .with_field("sql"));
163 }
164
165 for forbidden in &forbidden_prefixes {
167 let pattern = format!(r"\b{}\s+", forbidden);
169 if let Ok(re) = regex_lite::Regex::new(&pattern)
170 && re.is_match(&normalized)
171 {
172 return Err(ToolError::new(
173 ErrorCode::InvalidFieldValue,
174 format!("{} statements are not allowed", forbidden),
175 )
176 .with_field("sql"));
177 }
178 }
179
180 Ok(())
181}
182
183pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
185 let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
186
187 let params = get_string_array(&args, "params").unwrap_or_default();
188
189 let limit = get_i32(&args, "limit")
190 .map(|l| l.clamp(1, MAX_ROW_LIMIT))
191 .unwrap_or(DEFAULT_ROW_LIMIT);
192
193 let format = get_string(&args, "format")
195 .and_then(|f| QueryFormat::from_str(&f))
196 .unwrap_or_else(|| match default_format {
197 OutputFormat::Json => QueryFormat::Json,
198 OutputFormat::Markdown => QueryFormat::Markdown,
199 });
200
201 validate_readonly_sql(&sql)?;
203
204 let result = db.with_conn(|conn| {
206 conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
208
209 let mut stmt = conn.prepare(&sql)?;
211
212 let column_count = stmt.column_count();
214 let columns: Vec<String> = (0..column_count)
215 .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
216 .collect();
217
218 let params_refs: Vec<&dyn rusqlite::ToSql> = params
220 .iter()
221 .map(|s| s as &dyn rusqlite::ToSql)
222 .collect();
223
224 let mut rows_data: Vec<Vec<Value>> = Vec::new();
226 let mut row_iter = stmt.query(params_refs.as_slice())?;
227
228 let mut count = 0;
229 while let Some(row) = row_iter.next()? {
230 if count >= limit {
231 break;
232 }
233
234 let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
235 for i in 0..column_count {
236 let value: Value = match row.get_ref(i)? {
237 rusqlite::types::ValueRef::Null => Value::Null,
238 rusqlite::types::ValueRef::Integer(i) => json!(i),
239 rusqlite::types::ValueRef::Real(f) => json!(f),
240 rusqlite::types::ValueRef::Text(s) => {
241 json!(String::from_utf8_lossy(s).to_string())
242 }
243 rusqlite::types::ValueRef::Blob(b) => {
244 json!(base64::Engine::encode(
245 &base64::engine::general_purpose::STANDARD,
246 b
247 ))
248 }
249 };
250 row_values.push(value);
251 }
252 rows_data.push(row_values);
253 count += 1;
254 }
255
256 let has_more = row_iter.next()?.is_some();
258
259 Ok((columns, rows_data, count, has_more))
260 })?;
261
262 let (columns, rows_data, row_count, truncated) = result;
263
264 match format {
266 QueryFormat::Json => {
267 let rows: Vec<Value> = rows_data
269 .iter()
270 .map(|row| {
271 let obj: serde_json::Map<String, Value> = columns
272 .iter()
273 .zip(row.iter())
274 .map(|(col, val)| (col.clone(), val.clone()))
275 .collect();
276 Value::Object(obj)
277 })
278 .collect();
279
280 Ok(ToolResult::Json(json!({
281 "columns": columns,
282 "rows": rows,
283 "row_count": row_count,
284 "truncated": truncated,
285 "limit": limit
286 })))
287 }
288 QueryFormat::Csv => {
289 let mut csv = String::new();
290 csv.push_str(&columns.join(","));
292 csv.push('\n');
293 for row in &rows_data {
295 let values: Vec<String> = row
296 .iter()
297 .map(|v| match v {
298 Value::Null => String::new(),
299 Value::String(s) => {
300 if s.contains(',') || s.contains('"') || s.contains('\n') {
302 format!("\"{}\"", s.replace('"', "\"\""))
303 } else {
304 s.clone()
305 }
306 }
307 _ => v.to_string(),
308 })
309 .collect();
310 csv.push_str(&values.join(","));
311 csv.push('\n');
312 }
313
314 if truncated {
316 csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
317 }
318 Ok(ToolResult::Raw(csv))
319 }
320 QueryFormat::Markdown => {
321 let mut md = String::new();
322
323 if columns.is_empty() {
324 md.push_str("*No columns*\n");
325 } else {
326 md.push_str("| ");
328 md.push_str(&columns.join(" | "));
329 md.push_str(" |\n");
330
331 md.push_str("| ");
333 md.push_str(
334 &columns
335 .iter()
336 .map(|_| "---")
337 .collect::<Vec<_>>()
338 .join(" | "),
339 );
340 md.push_str(" |\n");
341
342 for row in &rows_data {
344 md.push_str("| ");
345 let values: Vec<String> = row
346 .iter()
347 .map(|v| match v {
348 Value::Null => String::from("*null*"),
349 Value::String(s) => s.replace('|', "\\|"),
350 _ => v.to_string(),
351 })
352 .collect();
353 md.push_str(&values.join(" | "));
354 md.push_str(" |\n");
355 }
356 }
357
358 if truncated {
359 md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
360 }
361
362 Ok(ToolResult::Raw(md))
363 }
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_readonly_select() {
        let accepted = [
            "SELECT * FROM tasks",
            " SELECT id FROM tasks WHERE status = 'pending' ",
            "select count(*) from tasks",
        ];
        for sql in accepted {
            assert!(validate_readonly_sql(sql).is_ok(), "should accept: {sql:?}");
        }
    }

    #[test]
    fn test_validate_readonly_with_cte() {
        let cte = "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts";
        assert!(validate_readonly_sql(cte).is_ok());
    }

    #[test]
    fn test_validate_readonly_rejects_insert() {
        let err = validate_readonly_sql("INSERT INTO tasks (title) VALUES ('test')")
            .expect_err("INSERT must be rejected");
        assert!(err.message.contains("INSERT"));
    }

    #[test]
    fn test_validate_readonly_rejects_update() {
        assert!(validate_readonly_sql("UPDATE tasks SET status = 'done'").is_err());
    }

    #[test]
    fn test_validate_readonly_rejects_delete() {
        assert!(validate_readonly_sql("DELETE FROM tasks WHERE id = 'xxx'").is_err());
    }

    #[test]
    fn test_validate_readonly_rejects_drop() {
        assert!(validate_readonly_sql("DROP TABLE tasks").is_err());
    }

    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let err = validate_readonly_sql("SELECT 1; DROP TABLE tasks;")
            .expect_err("stacked statements must be rejected");
        assert!(err.message.contains("Multiple"));
    }

    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        // Keywords embedded in identifiers (deleted_at, updated_at, created_at) must not trip
        // the forbidden-keyword screen.
        for sql in [
            "SELECT deleted_at FROM tasks",
            "SELECT updated_at, created_at FROM tasks",
        ] {
            assert!(validate_readonly_sql(sql).is_ok(), "should accept: {sql:?}");
        }
    }

    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QueryFormat::from_str(input), expected, "input: {input:?}");
        }
    }
}