Skip to main content

task_graph_mcp/tools/
query.rs

1//! Read-only SQL query tool.
2//!
3//! Provides a `query` tool for executing read-only SQL queries against the database.
4//! This tool is intended for advanced users and debugging purposes.
5//!
6//! SECURITY: This tool requires user permission before execution and only allows
7//! SELECT statements. INSERT, UPDATE, DELETE, DROP, and other modifying statements
8//! are rejected.
9
10use super::{get_i32, get_string, get_string_array, make_tool};
11use crate::db::Database;
12use crate::error::{ErrorCode, ToolError};
13use crate::format::{OutputFormat, ToolResult};
14use anyhow::Result;
15use rmcp::model::{Tool, ToolAnnotations};
16use serde_json::{json, Value};
17use std::time::Duration;
18
/// Row cap applied when the caller does not supply a `limit` argument.
const DEFAULT_ROW_LIMIT: i32 = 100;

/// Largest row cap a caller may request; larger `limit` values are clamped to this.
const MAX_ROW_LIMIT: i32 = 1_000;

/// Number of seconds handed to the connection's busy timeout before a query is run.
const QUERY_TIMEOUT_SECS: u64 = 5;
27
/// Rendering applied to the rows returned by the `query` tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryFormat {
    /// Structured JSON output.
    Json,
    /// Comma-separated text.
    Csv,
    /// Markdown table.
    Markdown,
}

impl QueryFormat {
    /// Parse a format name, ignoring letter case.
    ///
    /// Accepts `"json"`, `"csv"`, `"markdown"` and the alias `"md"`. Any other input
    /// (including surrounding whitespace) yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Every accepted spelling paired with the variant it selects.
        const NAMES: [(&str, QueryFormat); 4] = [
            ("json", QueryFormat::Json),
            ("csv", QueryFormat::Csv),
            ("markdown", QueryFormat::Markdown),
            ("md", QueryFormat::Markdown),
        ];

        let wanted = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == wanted)
            .map(|&(_, format)| format)
    }
}
46
47/// Get all query-related tools.
48pub fn get_tools() -> Vec<Tool> {
49    let mut tool = make_tool(
50        "query",
51        "Execute a read-only SQL query against the task database. REQUIRES USER PERMISSION. \
52         Only SELECT statements are allowed. Useful for custom queries, debugging, and \
53         advanced reporting. Returns columns, rows, and metadata.",
54        json!({
55            "sql": {
56                "type": "string",
57                "description": "SQL SELECT query to execute. Only SELECT statements are allowed."
58            },
59            "params": {
60                "type": "array",
61                "items": { "type": "string" },
62                "description": "Bind parameters for the query (use ? placeholders in SQL)"
63            },
64            "limit": {
65                "type": "integer",
66                "description": "Maximum number of rows to return (default: 100, max: 1000)"
67            },
68            "format": {
69                "type": "string",
70                "enum": ["json", "csv", "markdown"],
71                "description": "Output format for results (default: json)"
72            }
73        }),
74        vec!["sql"],
75    );
76
77    // Add annotations to indicate this is a read-only but potentially sensitive tool
78    // The destructiveHint is false because we only allow SELECT
79    // readOnlyHint is true because we don't modify data
80    tool.annotations = Some(ToolAnnotations {
81        title: Some("SQL Query".into()),
82        read_only_hint: Some(true),
83        destructive_hint: Some(false),
84        idempotent_hint: Some(true),
85        open_world_hint: Some(false),
86    });
87
88    vec![tool]
89}
90
91/// Validate that a SQL query is read-only (SELECT only).
92fn validate_readonly_sql(sql: &str) -> Result<(), ToolError> {
93    // Normalize whitespace and convert to uppercase for checking
94    let normalized = sql.trim().to_uppercase();
95
96    // Check for forbidden statements
97    let forbidden_prefixes = [
98        "INSERT",
99        "UPDATE",
100        "DELETE",
101        "DROP",
102        "CREATE",
103        "ALTER",
104        "TRUNCATE",
105        "REPLACE",
106        "UPSERT",
107        "MERGE",
108        "GRANT",
109        "REVOKE",
110        "ATTACH",
111        "DETACH",
112        "VACUUM",
113        "REINDEX",
114        "ANALYZE",
115        "PRAGMA",  // Some PRAGMAs can modify settings
116    ];
117
118    // Get the first word (statement type)
119    let first_word = normalized.split_whitespace().next().unwrap_or("");
120
121    if first_word != "SELECT" && first_word != "WITH" {
122        return Err(ToolError::new(
123            ErrorCode::InvalidFieldValue,
124            format!(
125                "Only SELECT queries are allowed. Got: {}",
126                if first_word.len() > 20 {
127                    &first_word[..20]
128                } else {
129                    first_word
130                }
131            ),
132        )
133        .with_field("sql"));
134    }
135
136    // Additional check for CTEs (WITH ... SELECT is OK, but WITH ... INSERT/UPDATE/DELETE is not)
137    if first_word == "WITH" {
138        // Look for modification keywords after WITH clause
139        for forbidden in &forbidden_prefixes {
140            // Check if the forbidden keyword appears as a standalone word (not in quotes or names)
141            let pattern = format!(r"\b{}\b", forbidden);
142            if let Ok(re) = regex_lite::Regex::new(&pattern)
143                && re.is_match(&normalized)
144            {
145                return Err(ToolError::new(
146                    ErrorCode::InvalidFieldValue,
147                    format!("{} statements are not allowed in queries", forbidden),
148                )
149                .with_field("sql"));
150            }
151        }
152    }
153
154    // Check for semicolons that might indicate multiple statements
155    // (SQLite allows this but we want to prevent injection)
156    let semicolon_count = sql.matches(';').count();
157    if semicolon_count > 1 {
158        return Err(ToolError::new(
159            ErrorCode::InvalidFieldValue,
160            "Multiple SQL statements are not allowed",
161        )
162        .with_field("sql"));
163    }
164
165    // Check for forbidden keywords anywhere in the query (for subqueries or injection attempts)
166    for forbidden in &forbidden_prefixes {
167        // Use word boundary matching to avoid false positives like "DELETED_AT"
168        let pattern = format!(r"\b{}\s+", forbidden);
169        if let Ok(re) = regex_lite::Regex::new(&pattern)
170            && re.is_match(&normalized)
171        {
172            return Err(ToolError::new(
173                ErrorCode::InvalidFieldValue,
174                format!("{} statements are not allowed", forbidden),
175            )
176            .with_field("sql"));
177        }
178    }
179
180    Ok(())
181}
182
183/// Execute a read-only SQL query.
184pub fn query(db: &Database, default_format: OutputFormat, args: Value) -> Result<ToolResult> {
185    let sql = get_string(&args, "sql").ok_or_else(|| ToolError::missing_field("sql"))?;
186
187    let params = get_string_array(&args, "params").unwrap_or_default();
188
189    let limit = get_i32(&args, "limit")
190        .map(|l| l.clamp(1, MAX_ROW_LIMIT))
191        .unwrap_or(DEFAULT_ROW_LIMIT);
192
193    // Use explicit format if provided, otherwise use config default
194    let format = get_string(&args, "format")
195        .and_then(|f| QueryFormat::from_str(&f))
196        .unwrap_or_else(|| match default_format {
197            OutputFormat::Json => QueryFormat::Json,
198            OutputFormat::Markdown => QueryFormat::Markdown,
199        });
200
201    // Validate the query is read-only
202    validate_readonly_sql(&sql)?;
203
204    // Execute the query with timeout
205    let result = db.with_conn(|conn| {
206        // Set a busy timeout for this connection
207        conn.busy_timeout(Duration::from_secs(QUERY_TIMEOUT_SECS))?;
208
209        // Prepare the statement
210        let mut stmt = conn.prepare(&sql)?;
211
212        // Get column names
213        let column_count = stmt.column_count();
214        let columns: Vec<String> = (0..column_count)
215            .map(|i| stmt.column_name(i).unwrap_or("?").to_string())
216            .collect();
217
218        // Bind parameters
219        let params_refs: Vec<&dyn rusqlite::ToSql> = params
220            .iter()
221            .map(|s| s as &dyn rusqlite::ToSql)
222            .collect();
223
224        // Execute and collect rows
225        let mut rows_data: Vec<Vec<Value>> = Vec::new();
226        let mut row_iter = stmt.query(params_refs.as_slice())?;
227
228        let mut count = 0;
229        while let Some(row) = row_iter.next()? {
230            if count >= limit {
231                break;
232            }
233
234            let mut row_values: Vec<Value> = Vec::with_capacity(column_count);
235            for i in 0..column_count {
236                let value: Value = match row.get_ref(i)? {
237                    rusqlite::types::ValueRef::Null => Value::Null,
238                    rusqlite::types::ValueRef::Integer(i) => json!(i),
239                    rusqlite::types::ValueRef::Real(f) => json!(f),
240                    rusqlite::types::ValueRef::Text(s) => {
241                        json!(String::from_utf8_lossy(s).to_string())
242                    }
243                    rusqlite::types::ValueRef::Blob(b) => {
244                        json!(base64::Engine::encode(
245                            &base64::engine::general_purpose::STANDARD,
246                            b
247                        ))
248                    }
249                };
250                row_values.push(value);
251            }
252            rows_data.push(row_values);
253            count += 1;
254        }
255
256        // Check if there are more rows (for truncated flag)
257        let has_more = row_iter.next()?.is_some();
258
259        Ok((columns, rows_data, count, has_more))
260    })?;
261
262    let (columns, rows_data, row_count, truncated) = result;
263
264    // Format the output based on requested format
265    match format {
266        QueryFormat::Json => {
267            // Convert rows to objects with column names as keys
268            let rows: Vec<Value> = rows_data
269                .iter()
270                .map(|row| {
271                    let obj: serde_json::Map<String, Value> = columns
272                        .iter()
273                        .zip(row.iter())
274                        .map(|(col, val)| (col.clone(), val.clone()))
275                        .collect();
276                    Value::Object(obj)
277                })
278                .collect();
279
280            Ok(ToolResult::Json(json!({
281                "columns": columns,
282                "rows": rows,
283                "row_count": row_count,
284                "truncated": truncated,
285                "limit": limit
286            })))
287        }
288        QueryFormat::Csv => {
289            let mut csv = String::new();
290            // Header
291            csv.push_str(&columns.join(","));
292            csv.push('\n');
293            // Rows
294            for row in &rows_data {
295                let values: Vec<String> = row
296                    .iter()
297                    .map(|v| match v {
298                        Value::Null => String::new(),
299                        Value::String(s) => {
300                            // Escape quotes and wrap in quotes if contains comma or quotes
301                            if s.contains(',') || s.contains('"') || s.contains('\n') {
302                                format!("\"{}\"", s.replace('"', "\"\""))
303                            } else {
304                                s.clone()
305                            }
306                        }
307                        _ => v.to_string(),
308                    })
309                    .collect();
310                csv.push_str(&values.join(","));
311                csv.push('\n');
312            }
313
314            // CSV is raw text output
315            if truncated {
316                csv.push_str(&format!("\n# Results truncated at {} rows\n", limit));
317            }
318            Ok(ToolResult::Raw(csv))
319        }
320        QueryFormat::Markdown => {
321            let mut md = String::new();
322
323            if columns.is_empty() {
324                md.push_str("*No columns*\n");
325            } else {
326                // Header
327                md.push_str("| ");
328                md.push_str(&columns.join(" | "));
329                md.push_str(" |\n");
330
331                // Separator
332                md.push_str("| ");
333                md.push_str(
334                    &columns
335                        .iter()
336                        .map(|_| "---")
337                        .collect::<Vec<_>>()
338                        .join(" | "),
339                );
340                md.push_str(" |\n");
341
342                // Rows
343                for row in &rows_data {
344                    md.push_str("| ");
345                    let values: Vec<String> = row
346                        .iter()
347                        .map(|v| match v {
348                            Value::Null => String::from("*null*"),
349                            Value::String(s) => s.replace('|', "\\|"),
350                            _ => v.to_string(),
351                        })
352                        .collect();
353                    md.push_str(&values.join(" | "));
354                    md.push_str(" |\n");
355                }
356            }
357
358            if truncated {
359                md.push_str(&format!("\n*Results truncated at {} rows*\n", limit));
360            }
361
362            Ok(ToolResult::Raw(md))
363        }
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that the validator accepts every statement in `statements`.
    fn assert_all_accepted(statements: &[&str]) {
        for sql in statements {
            assert!(
                validate_readonly_sql(sql).is_ok(),
                "expected query to be accepted: {sql}"
            );
        }
    }

    /// Assert that the validator rejects `sql`.
    fn assert_rejected(sql: &str) {
        assert!(
            validate_readonly_sql(sql).is_err(),
            "expected query to be rejected: {sql}"
        );
    }

    #[test]
    fn test_validate_readonly_select() {
        // Plain, padded and lower-case SELECT statements all pass.
        assert_all_accepted(&[
            "SELECT * FROM tasks",
            "  SELECT id FROM tasks WHERE status = 'pending'  ",
            "select count(*) from tasks",
        ]);
    }

    #[test]
    fn test_validate_readonly_with_cte() {
        assert_all_accepted(&[
            "WITH task_counts AS (SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status) SELECT * FROM task_counts",
        ]);
    }

    #[test]
    fn test_validate_readonly_rejects_insert() {
        let err = validate_readonly_sql("INSERT INTO tasks (title) VALUES ('test')")
            .expect_err("INSERT must be rejected");
        // The error names the offending statement type.
        assert!(err.message.contains("INSERT"));
    }

    #[test]
    fn test_validate_readonly_rejects_update() {
        assert_rejected("UPDATE tasks SET status = 'done'");
    }

    #[test]
    fn test_validate_readonly_rejects_delete() {
        assert_rejected("DELETE FROM tasks WHERE id = 'xxx'");
    }

    #[test]
    fn test_validate_readonly_rejects_drop() {
        assert_rejected("DROP TABLE tasks");
    }

    #[test]
    fn test_validate_readonly_rejects_multiple_statements() {
        let err = validate_readonly_sql("SELECT 1; DROP TABLE tasks;")
            .expect_err("chained statements must be rejected");
        assert!(err.message.contains("Multiple"));
    }

    #[test]
    fn test_validate_readonly_allows_column_names_with_keywords() {
        // Identifiers that merely start with a keyword ("deleted_at", "updated_at",
        // "created_at") are not statements and must pass.
        assert_all_accepted(&[
            "SELECT deleted_at FROM tasks",
            "SELECT updated_at, created_at FROM tasks",
        ]);
    }

    #[test]
    fn test_query_format_parsing() {
        let cases = [
            ("json", Some(QueryFormat::Json)),
            ("JSON", Some(QueryFormat::Json)),
            ("csv", Some(QueryFormat::Csv)),
            ("markdown", Some(QueryFormat::Markdown)),
            ("md", Some(QueryFormat::Markdown)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(QueryFormat::from_str(input), expected);
        }
    }
}