Skip to main content

task_graph_mcp/tools/
query.rs

1//! Read-only SQL query tool.
2//!
3//! Provides a `query` tool for executing read-only SQL queries against the database.
4//! This tool is intended for advanced users and debugging purposes.
5//!
6//! SECURITY: This tool requires user permission before execution and only allows
7//! SELECT statements. INSERT, UPDATE, DELETE, DROP, and other modifying statements
8//! are rejected.
9
10use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{Value, json};
17use std::time::Duration;
18
/// Row cap applied when the caller does not supply a `limit` argument.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Upper bound that a caller-supplied `limit` is clamped to.
const MAX_ROW_LIMIT: i32 = 1_000;

/// Number of seconds handed to the connection's `busy_timeout` before a query runs.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Rendering options for the rows returned by the `query` tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// Structured JSON output.
    Json,
    /// Comma-separated text.
    Csv,
    /// Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Resolves a format name to its variant, ignoring letter case.
    ///
    /// Accepts `json`, `csv`, `markdown`, and the `md` alias; any other name
    /// yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Every accepted spelling, already in lower case.
        const NAMES: [(&str, QueryFormat); 4] = [
            ("json", QueryFormat::Json),
            ("csv", QueryFormat::Csv),
            ("markdown", QueryFormat::Markdown),
            ("md", QueryFormat::Markdown),
        ];

        let wanted = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == wanted.as_str())
            .map(|&(_, format)| format)
    }
}
46
47/// Get all query-related tools.
48pub fn get_tools() -> Vec<Tool> {
49    let mut tool = make_tool(
50        "query",
51        "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52         Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53         advanced reporting. Returns columns, rows, and metadata.",
54        json!({
55            "sql": {
56                "type": "string",
57                "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58            },
59            "params": {
60                "type": "array",
61                "items": { "type": "string" },
62                "description": "Bind parameters for the query (use ? placeholders in SQL)"
63            },
64            "limit": {
65                "type": "integer",
66                "description": "Maximum number of rows to return (default: 100, max: 1000)"
67            },
68            "format": {
69                "type": "string",
70                "enum": ["json", "csv", "markdown"],
71                "description": "Output format for results (default: json)"
72            }
73        }),
74        vec!["sql"],
75    );
76
77    // Add annotations to indicate this is a read-only but potentially sensitive tool
78    // The destructiveHint is false because we only allow SELECT
79    // readOnlyHint is true because we don't modify data
80    tool.annotations = Some(ToolAnnotations {
81        title: Some("SQL Query".into()),
82        read_only_hint: Some(true),
83        destructive_hint: Some(false),
84        idempotent_hint: Some(true),
85        open_world_hint: Some(false),
86    });
87
88    vec![tool]
89}
90
91/// Validate that a SQL query is read-only (SELECT only).
92fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
93    // Normalize whitespace and convert to uppercase for checking
94    let normalized = sql.trim().to_uppercase();
95
96    // Check for forbidden statements
97    let forbidden_prefixes = [
98        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE", "REPLACE", "UPSERT",
99        "MERGE", "GRANT", "REVOKE", "ATTACH", "DETACH", "VACUUM", "REINDEX", "ANALYZE",
100        "PRAGMA", // Some PRAGMAs can modify settings
101    ];
102
103    // Get the first word (statement type)
104    let first_word = normalized.split_whitespace().next().unwrap_or("");
105
106    if first_word != "SELECT" && first_word != "WITH" {
107        return Err(ToolError::new(
108            ErrorCode::InvalidFieldValue,
109            format!(
110                "Only SELECT queries are allowed. Got: {}",
111                if first_word.len() > 20 {
112                    &first_word[..20]
113                } else {
114                    first_word
115                }
116            ),
117        )
118        .with_field("sql"));
119    }
120
121    // Additional check for CTEs (WITH ... SELECT is OK, but WITH ... INSERT/UPDATE/DELETE is not)
122    if first_word == "WITH" {
123        // Look for modification keywords after WITH clause
124        for forbidden in &forbidden_prefixes {
125            // Check if the forbidden keyword appears as a standalone word (not in quotes or names)
126            let pattern = format!(r"\b{}\b", forbidden);
127            if let Ok(re) = regex_lite::Regex::new(&pattern)
128                && re.is_match(&normalized)
129            {
130                return Err(ToolError::new(
131                    ErrorCode::InvalidFieldValue,
132                    format!("{} statements are not allowed in queries", forbidden),
133                )
134                .with_field("sql"));
135            }
136        }
137    }
138
139    // Check for semicolons that might indicate multiple statements
140    // (SQLite allows this but we want to prevent injection)
141    let semicolon_count = sql.matches(';').count();
142    if semicolon_count > 1 {
143        return Err(ToolError::new(
144            ErrorCode::InvalidFieldValue,
145            "Multiple SQL statements are not allowed",
146        )
147        .with_field("sql"));
148    }
149
150    // Check for forbidden keywords anywhere in the query (for subqueries or injection attempts)
151    for forbidden in &forbidden_prefixes {
152        // Use word boundary matching to avoid false positives like "DELETED_AT"
153        let pattern = format!(r"\b{}\s+", forbidden);
154        if let Ok(re) = regex_lite::Regex::new(&pattern)
155            && re.is_match(&normalized)
156        {
157            return Err(ToolError::new(
158                ErrorCode::InvalidFieldValue,
159                format!("{} statements are not allowed", forbidden),
160            )
161            .with_field("sql"));
162        }
163    }
164
165    Ok(())
166}
167
168/// Execute a read-only SQL query.
169pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
170    let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
171
172    let params = get_string_array(&args, "params").unwrap_or_default();
173
174    let limit = get_i32(&args, "limit")
175        .map(|l| l.clamp(1, MAX_ROW_LIMIT))
176        .unwrap_or(DEFAULT_ROW_LIMIT);
177
178    // Use explicit format if provided, otherwise use config default
179    let format = get_string(&args, "format")
180        .and_then(|f| QueryFormat::from_str(&f))
181        .unwrap_or(match default_format {
182            OutputFormat::Json => QueryFormat::Json,
183            OutputFormat::Markdown => QueryFormat::Markdown,
184        });
185
186    // Validate the query is read-only
187    validate_readonly_sql(&sql)?;
188
189    // Execute the query with timeout
190    let result = db.with_conn(|conn| {
191        // Set a busy timeout for this connection
192        conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
193
194        // Prepare the statement
195        let mut stmt = conn.prepare(&sql)?;
196
197        // Get column names
198        let column_count = stmt.column_count();
199        let columns: Vec<String> = (0..column_count)
200            .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
201            .collect();
202
203        // Bind parameters
204        let params_refs: Vec<&dyn rusqlite::ToSql> =
205            params.iter().map(|s| s as &dyn rusqlite::ToSql).collect();
206
207        // Execute and collect rows
208        let mut rows_data: Vec<Vec<Value>> = Vec::new();
209        let mut row_iter = stmt.query(params_refs.as_slice())?;
210
211        let mut count = 0;
212        while let Some(row) = row_iter.next()? {
213            if count >= limit {
214                break;
215            }
216
217            let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
218            for i in 0..column_count {
219                let value: Value = match row.get_ref(i)? {
220                    rusqlite::types::ValueRef::Null => Value::Null,
221                    rusqlite::types::ValueRef::Integer(i) => json!(i),
222                    rusqlite::types::ValueRef::Real(f) => json!(f),
223                    rusqlite::types::ValueRef::Text(s) => {
224                        json!(String::from_utf8_lossy(s).to_string())
225                    }
226                    rusqlite::types::ValueRef::Blob(b) => {
227                        json!(base64::Engine::encode(
228                            &base64::engine::general_purpose::STANDARD,
229                            b
230                        ))
231                    }
232                };
233                row_values.push(value);
234            }
235            rows_data.push(row_values);
236            count += 1;
237        }
238
239        // Check if there are more rows (for truncated flag)
240        let has_more = row_iter.next()?.is_some();
241
242        Ok((columns, rows_data, count, has_more))
243    })?;
244
245    let (columns, rows_data, row_count, truncated) = result;
246
247    // Format the output based on requested format
248    match format {
249        QueryFormat::Json => {
250            // Convert rows to objects with column names as keys
251            let rows: Vec<Value> = rows_data
252                .iter()
253                .map(|row| {
254                    let obj: serde_json::Map<String, Value> = columns
255                        .iter()
256                        .zip(row.iter())
257                        .map(|(col, val)| (col.clone(), val.clone()))
258                        .collect();
259                    Value::Object(obj)
260                })
261                .collect();
262
263            Ok(ToolResult::Json(json!({
264                "columns": columns,
265                "rows": rows,
266                "row_count": row_count,
267                "truncated": truncated,
268                "limit": limit
269            })))
270        }
271        QueryFormat::Csv => {
272            let mut csv = String::new();
273            // Header
274            csv.push_str(&columns.join(","));
275            csv.push('\n');
276            // Rows
277            for row in &rows_data {
278                let values: Vec<String> = row
279                    .iter()
280                    .map(|v| match v {
281                        Value::Null => String::new(),
282                        Value::String(s) => {
283                            // Escape quotes and wrap in quotes if contains comma or quotes
284                            if s.contains(',') || s.contains('"') || s.contains('\n') {
285                                format!("\"{}\"", s.replace('"', "\"\""))
286                            } else {
287                                s.clone()
288                            }
289                        }
290                        _ => v.to_string(),
291                    })
292                    .collect();
293                csv.push_str(&values.join(","));
294                csv.push('\n');
295            }
296
297            // CSV is raw text output
298            if truncated {
299                csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
300            }
301            Ok(ToolResult::Raw(csv))
302        }
303        QueryFormat::Markdown => {
304            let mut md = String::new();
305
306            if columns.is_empty() {
307                md.push_str("*No columns*\n");
308            } else {
309                // Header
310                md.push_str("| ");
311                md.push_str(&columns.join(" | "));
312                md.push_str(" |\n");
313
314                // Separator
315                md.push_str("| ");
316                md.push_str(
317                    &columns
318                        .iter()
319                        .map(|_| "---")
320                        .collect::<Vec<_>>()
321                        .join(" | "),
322                );
323                md.push_str(" |\n");
324
325                // Rows
326                for row in &rows_data {
327                    md.push_str("| ");
328                    let values: Vec<String> = row
329                        .iter()
330                        .map(|v| match v {
331                            Value::Null => String::from("*null*"),
332                            Value::String(s) => s.replace('|', "\\|"),
333                            _ => v.to_string(),
334                        })
335                        .collect();
336                    md.push_str(&values.join(" | "));
337                    md.push_str(" |\n");
338                }
339            }
340
341            if truncated {
342                md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
343            }
344
345            Ok(ToolResult::Raw(md))
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain SELECT statements pass regardless of case or surrounding whitespace.
    #[test]
    fn test_validate_readonly_select() {
        let accepted = [
            "SELECT * FROM tasks",
            "  SELECT id FROM tasks WHERE status = 'pending'  ",
            "select count(*) from tasks",
        ];
        for sql in accepted {
            assert!(validate_readonly_sql(sql).is_ok());
        }
    }

    /// A CTE that ends in a SELECT is accepted.
    #[test]
    fn test_validate_readonly_with_cte() {
        let sql = "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts";
        assert!(validate_readonly_sql(sql).is_ok());
    }

    /// INSERT is rejected and the error names the offending keyword.
    #[test]
    fn test_validate_readonly_rejects_insert() {
        let outcome = validate_readonly_sql("INSERT INTO tasks (title) VALUES ('test')");
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(err.message.contains("INSERT"));
    }

    /// UPDATE is rejected.
    #[test]
    fn test_validate_readonly_rejects_update() {
        assert!(validate_readonly_sql("UPDATE tasks SET status = 'done'").is_err());
    }

    /// DELETE is rejected.
    #[test]
    fn test_validate_readonly_rejects_delete() {
        assert!(validate_readonly_sql("DELETE FROM tasks WHERE id = 'xxx'").is_err());
    }

    /// DROP is rejected.
    #[test]
    fn test_validate_readonly_rejects_drop() {
        assert!(validate_readonly_sql("DROP TABLE tasks").is_err());
    }

    /// Two semicolons trigger the multiple-statement error.
    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let outcome = validate_readonly_sql("SELECT 1; DROP TABLE tasks;");
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(err.message.contains("Multiple"));
    }

    /// Column names like "deleted_at" or "updated_at" must not be mistaken for
    /// forbidden statements.
    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        let accepted = [
            "SELECT deleted_at FROM tasks",
            "SELECT updated_at, created_at FROM tasks",
        ];
        for sql in accepted {
            assert!(validate_readonly_sql(sql).is_ok());
        }
    }

    /// Format names resolve case-insensitively; unknown names give `None`.
    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (name, expected) in cases {
            assert_eq!(QueryFormat::from_str(name), expected);
        }
    }
}