Skip to main content

task_graph_mcp/tools/
query.rs

1//! Read-only SQL query tool.
2//!
3//! Provides a `query` tool for executing read-only SQL queries against the database.
4//! This tool is intended for advanced users and debugging purposes.
5//!
//! SECURITY: This tool requires user permission before execution and only allows
//! SELECT statements (optionally introduced by a `WITH` clause). INSERT, UPDATE,
//! DELETE, DROP, and other modifying statements are rejected.
9
10use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{Value, json};
17use std::time::Duration;
18
/// Number of rows returned when the caller does not pass `limit`.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Upper bound for a caller-supplied `limit`; requested values are clamped to
/// `1..=MAX_ROW_LIMIT`.
const MAX_ROW_LIMIT: i32 = 1000;

/// SQLite busy timeout, in seconds, set on the connection before a query runs.
///
/// This bounds how long SQLite waits for a locked database to become available;
/// it does NOT cap how long the query itself may run once it has started.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Rendering style for the rows returned by the `query` tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// Structured JSON: one object per row plus column/row metadata.
    Json,
    /// Comma-separated values with a header line.
    Csv,
    /// Pipe-delimited Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Parses a format name case-insensitively.
    ///
    /// Recognises `json`, `csv`, `markdown`, and the `md` shorthand; any other
    /// input yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        const NAMES: [(&str, QueryFormat); 4] = [
            ("json", QueryFormat::Json),
            ("csv", QueryFormat::Csv),
            ("markdown", QueryFormat::Markdown),
            ("md", QueryFormat::Markdown),
        ];

        let wanted = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == wanted.as_str())
            .map(|&(_, format)| format)
    }
}
46
47/// Get all query-related tools.
48pub fn get_tools() -> Vec<Tool> {
49    let mut tool = make_tool(
50        "query",
51        "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52         Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53         advanced reporting. Returns columns, rows, and metadata.",
54        json!({
55            "sql": {
56                "type": "string",
57                "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58            },
59            "params": {
60                "type": "array",
61                "items": { "type": "string" },
62                "description": "Bind parameters for the query (use ? placeholders in SQL)"
63            },
64            "limit": {
65                "type": "integer",
66                "description": "Maximum number of rows to return (default: 100, max: 1000)"
67            },
68            "format": {
69                "type": "string",
70                "enum": ["json", "csv", "markdown"],
71                "description": "Output format for results (default: json)"
72            }
73        }),
74        vec!["sql"],
75    );
76
77    // Add annotations to indicate this is a read-only but potentially sensitive tool
78    // The destructiveHint is false because we only allow SELECT
79    // readOnlyHint is true because we don't modify data
80    tool.annotations = Some(ToolAnnotations {
81        title: Some("SQL Query".into()),
82        read_only_hint: Some(true),
83        destructive_hint: Some(false),
84        idempotent_hint: Some(true),
85        open_world_hint: Some(false),
86    });
87
88    vec![tool]
89}
90
91/// Validate that a SQL query is read-only (SELECT only).
92#[allow(clippy::result_large_err)]
93fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
94    // Normalize whitespace and convert to uppercase for checking
95    let normalized = sql.trim().to_uppercase();
96
97    // Check for forbidden statements
98    let forbidden_prefixes = [
99        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE", "REPLACE", "UPSERT",
100        "MERGE", "GRANT", "REVOKE", "ATTACH", "DETACH", "VACUUM", "REINDEX", "ANALYZE",
101        "PRAGMA", // Some PRAGMAs can modify settings
102    ];
103
104    // Get the first word (statement type)
105    let first_word = normalized.split_whitespace().next().unwrap_or("");
106
107    if first_word != "SELECT" && first_word != "WITH" {
108        return Err(ToolError::new(
109            ErrorCode::InvalidFieldValue,
110            format!(
111                "Only SELECT queries are allowed. Got: {}",
112                if first_word.len() > 20 {
113                    &first_word[..20]
114                } else {
115                    first_word
116                }
117            ),
118        )
119        .with_field("sql"));
120    }
121
122    // Additional check for CTEs (WITH ... SELECT is OK, but WITH ... INSERT/UPDATE/DELETE is not)
123    if first_word == "WITH" {
124        // Look for modification keywords after WITH clause
125        for forbidden in &forbidden_prefixes {
126            // Check if the forbidden keyword appears as a standalone word (not in quotes or names)
127            let pattern = format!(r"\b{}\b", forbidden);
128            if let Ok(re) = regex_lite::Regex::new(&pattern)
129                && re.is_match(&normalized)
130            {
131                return Err(ToolError::new(
132                    ErrorCode::InvalidFieldValue,
133                    format!("{} statements are not allowed in queries", forbidden),
134                )
135                .with_field("sql"));
136            }
137        }
138    }
139
140    // Check for semicolons that might indicate multiple statements
141    // (SQLite allows this but we want to prevent injection)
142    let semicolon_count = sql.matches(';').count();
143    if semicolon_count > 1 {
144        return Err(ToolError::new(
145            ErrorCode::InvalidFieldValue,
146            "Multiple SQL statements are not allowed",
147        )
148        .with_field("sql"));
149    }
150
151    // Check for forbidden keywords anywhere in the query (for subqueries or injection attempts)
152    for forbidden in &forbidden_prefixes {
153        // Use word boundary matching to avoid false positives like "DELETED_AT"
154        let pattern = format!(r"\b{}\s+", forbidden);
155        if let Ok(re) = regex_lite::Regex::new(&pattern)
156            && re.is_match(&normalized)
157        {
158            return Err(ToolError::new(
159                ErrorCode::InvalidFieldValue,
160                format!("{} statements are not allowed", forbidden),
161            )
162            .with_field("sql"));
163        }
164    }
165
166    Ok(())
167}
168
169/// Execute a read-only SQL query.
170pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
171    let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
172
173    let params = get_string_array(&args, "params").unwrap_or_default();
174
175    let limit = get_i32(&args, "limit")
176        .map(|l| l.clamp(1, MAX_ROW_LIMIT))
177        .unwrap_or(DEFAULT_ROW_LIMIT);
178
179    // Use explicit format if provided, otherwise use config default
180    let format = get_string(&args, "format")
181        .and_then(|f| QueryFormat::from_str(&f))
182        .unwrap_or(match default_format {
183            OutputFormat::Json => QueryFormat::Json,
184            OutputFormat::Markdown => QueryFormat::Markdown,
185        });
186
187    // Validate the query is read-only
188    validate_readonly_sql(&sql)?;
189
190    // Execute the query with timeout
191    let result = db.with_conn(|conn| {
192        // Set a busy timeout for this connection
193        conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
194
195        // Prepare the statement
196        let mut stmt = conn.prepare(&sql)?;
197
198        // Get column names
199        let column_count = stmt.column_count();
200        let columns: Vec<String> = (0..column_count)
201            .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
202            .collect();
203
204        // Bind parameters
205        let params_refs: Vec<&dyn rusqlite::ToSql> =
206            params.iter().map(|s| s as &dyn rusqlite::ToSql).collect();
207
208        // Execute and collect rows
209        let mut rows_data: Vec<Vec<Value>> = Vec::new();
210        let mut row_iter = stmt.query(params_refs.as_slice())?;
211
212        let mut count = 0;
213        while let Some(row) = row_iter.next()? {
214            if count >= limit {
215                break;
216            }
217
218            let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
219            for i in 0..column_count {
220                let value: Value = match row.get_ref(i)? {
221                    rusqlite::types::ValueRef::Null => Value::Null,
222                    rusqlite::types::ValueRef::Integer(i) => json!(i),
223                    rusqlite::types::ValueRef::Real(f) => json!(f),
224                    rusqlite::types::ValueRef::Text(s) => {
225                        json!(String::from_utf8_lossy(s).to_string())
226                    }
227                    rusqlite::types::ValueRef::Blob(b) => {
228                        json!(base64::Engine::encode(
229                            &base64::engine::general_purpose::STANDARD,
230                            b
231                        ))
232                    }
233                };
234                row_values.push(value);
235            }
236            rows_data.push(row_values);
237            count += 1;
238        }
239
240        // Check if there are more rows (for truncated flag)
241        let has_more = row_iter.next()?.is_some();
242
243        Ok((columns, rows_data, count, has_more))
244    })?;
245
246    let (columns, rows_data, row_count, truncated) = result;
247
248    // Format the output based on requested format
249    match format {
250        QueryFormat::Json => {
251            // Convert rows to objects with column names as keys
252            let rows: Vec<Value> = rows_data
253                .iter()
254                .map(|row| {
255                    let obj: serde_json::Map<String, Value> = columns
256                        .iter()
257                        .zip(row.iter())
258                        .map(|(col, val)| (col.clone(), val.clone()))
259                        .collect();
260                    Value::Object(obj)
261                })
262                .collect();
263
264            Ok(ToolResult::Json(json!({
265                "columns": columns,
266                "rows": rows,
267                "row_count": row_count,
268                "truncated": truncated,
269                "limit": limit
270            })))
271        }
272        QueryFormat::Csv => {
273            let mut csv = String::new();
274            // Header
275            csv.push_str(&columns.join(","));
276            csv.push('\n');
277            // Rows
278            for row in &rows_data {
279                let values: Vec<String> = row
280                    .iter()
281                    .map(|v| match v {
282                        Value::Null => String::new(),
283                        Value::String(s) => {
284                            // Escape quotes and wrap in quotes if contains comma or quotes
285                            if s.contains(',') || s.contains('"') || s.contains('\n') {
286                                format!("\"{}\"", s.replace('"', "\"\""))
287                            } else {
288                                s.clone()
289                            }
290                        }
291                        _ => v.to_string(),
292                    })
293                    .collect();
294                csv.push_str(&values.join(","));
295                csv.push('\n');
296            }
297
298            // CSV is raw text output
299            if truncated {
300                csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
301            }
302            Ok(ToolResult::Raw(csv))
303        }
304        QueryFormat::Markdown => {
305            let mut md = String::new();
306
307            if columns.is_empty() {
308                md.push_str("*No columns*\n");
309            } else {
310                // Header
311                md.push_str("| ");
312                md.push_str(&columns.join(" | "));
313                md.push_str(" |\n");
314
315                // Separator
316                md.push_str("| ");
317                md.push_str(
318                    &columns
319                        .iter()
320                        .map(|_| "---")
321                        .collect::<Vec<_>>()
322                        .join(" | "),
323                );
324                md.push_str(" |\n");
325
326                // Rows
327                for row in &rows_data {
328                    md.push_str("| ");
329                    let values: Vec<String> = row
330                        .iter()
331                        .map(|v| match v {
332                            Value::Null => String::from("*null*"),
333                            Value::String(s) => s.replace('|', "\\|"),
334                            _ => v.to_string(),
335                        })
336                        .collect();
337                    md.push_str(&values.join(" | "));
338                    md.push_str(" |\n");
339                }
340            }
341
342            if truncated {
343                md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
344            }
345
346            Ok(ToolResult::Raw(md))
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `sql` passes read-only validation.
    fn assert_allowed(sql: &str) {
        assert!(
            validate_readonly_sql(sql).is_ok(),
            "expected query to be allowed: {sql}"
        );
    }

    /// Panics unless `sql` fails read-only validation; returns the rejection.
    fn rejected(sql: &str) -> ToolError {
        match validate_readonly_sql(sql) {
            Ok(()) => panic!("expected query to be rejected: {sql}"),
            Err(err) => err,
        }
    }

    #[test]
    fn test_validate_readonly_select() {
        [
            "SELECT * FROM tasks",
            "  SELECT id FROM tasks WHERE status = 'pending'  ",
            "select count(*) from tasks",
        ]
        .into_iter()
        .for_each(assert_allowed);
    }

    #[test]
    fn test_validate_readonly_with_cte() {
        assert_allowed(
            "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts",
        );
    }

    #[test]
    fn test_validate_readonly_rejects_insert() {
        let err = rejected("INSERT INTO tasks (title) VALUES ('test')");
        assert!(err.message.contains("INSERT"));
    }

    #[test]
    fn test_validate_readonly_rejects_update() {
        rejected("UPDATE tasks SET status = 'done'");
    }

    #[test]
    fn test_validate_readonly_rejects_delete() {
        rejected("DELETE FROM tasks WHERE id = 'xxx'");
    }

    #[test]
    fn test_validate_readonly_rejects_drop() {
        rejected("DROP TABLE tasks");
    }

    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let err = rejected("SELECT 1; DROP TABLE tasks;");
        assert!(err.message.contains("Multiple"));
    }

    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        // Keyword-prefixed column names such as "deleted_at" must not trip the filter.
        assert_allowed("SELECT deleted_at FROM tasks");
        assert_allowed("SELECT updated_at, created_at FROM tasks");
    }

    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QueryFormat::from_str(input), expected, "format name: {input}");
        }
    }
}