Skip to main content

synaptic_sqltoolkit/
lib.rs

1//! SQL database toolkit for the Synaptic framework.
2//!
3//! Provides a set of read-only SQL tools for use with LLM agents:
4//!
5//! - [`ListTablesTool`] — lists all tables in the database
6//! - [`DescribeTableTool`] — returns column info for a table
7//! - [`ExecuteQueryTool`] — runs a SELECT query and returns results as JSON
8//!
9//! # Quick start
10//!
11//! ```rust,ignore
12//! use sqlx::sqlite::SqlitePoolOptions;
13//! use synaptic_sqltoolkit::SqlToolkit;
14//! use synaptic_tools::ToolRegistry;
15//! use std::sync::Arc;
16//!
17//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
18//! let pool = SqlitePoolOptions::new()
19//!     .connect("sqlite::memory:")
20//!     .await?;
21//!
22//! let toolkit = SqlToolkit::sqlite(pool);
23//! let registry = ToolRegistry::new();
24//! for tool in toolkit.tools() {
25//!     registry.register(tool)?;
26//! }
27//! # Ok(())
28//! # }
29//! ```
30
31use async_trait::async_trait;
32use serde_json::{json, Value};
33use sqlx::{Column, Row, SqlitePool, TypeInfo};
34use std::sync::Arc;
35use synaptic_core::{SynapticError, Tool};
36
37// ---------------------------------------------------------------------------
38// SqlToolkit
39// ---------------------------------------------------------------------------
40
41/// A toolkit that provides SQL tools for agent use.
42pub struct SqlToolkit {
43    pool: SqlitePool,
44}
45
46impl SqlToolkit {
47    /// Create a toolkit backed by a SQLite pool.
48    pub fn sqlite(pool: SqlitePool) -> Self {
49        Self { pool }
50    }
51
52    /// Return the set of tools provided by this toolkit.
53    pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
54        vec![
55            Arc::new(ListTablesTool {
56                pool: self.pool.clone(),
57            }),
58            Arc::new(DescribeTableTool {
59                pool: self.pool.clone(),
60            }),
61            Arc::new(ExecuteQueryTool {
62                pool: self.pool.clone(),
63            }),
64        ]
65    }
66}
67
68// ---------------------------------------------------------------------------
69// ListTablesTool
70// ---------------------------------------------------------------------------
71
72/// Tool that lists all tables in the SQLite database.
73pub struct ListTablesTool {
74    pool: SqlitePool,
75}
76
77#[async_trait]
78impl Tool for ListTablesTool {
79    fn name(&self) -> &'static str {
80        "sql_list_tables"
81    }
82
83    fn description(&self) -> &'static str {
84        "List all tables available in the SQL database. Returns a JSON array of table names."
85    }
86
87    fn parameters(&self) -> Option<Value> {
88        Some(json!({
89            "type": "object",
90            "properties": {},
91            "required": []
92        }))
93    }
94
95    async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
96        let rows = sqlx::query(
97            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
98        )
99        .fetch_all(&self.pool)
100        .await
101        .map_err(|e| SynapticError::Tool(format!("ListTables error: {e}")))?;
102
103        let tables: Vec<String> = rows.iter().map(|r| r.get::<String, _>("name")).collect();
104
105        Ok(json!({ "tables": tables }))
106    }
107}
108
109// ---------------------------------------------------------------------------
110// DescribeTableTool
111// ---------------------------------------------------------------------------
112
113/// Tool that returns the schema of a specific table.
114pub struct DescribeTableTool {
115    pool: SqlitePool,
116}
117
/// Checks that `name` is made up solely of identifier-safe characters.
///
/// A name passes when it is non-empty and every character is either an
/// underscore or alphanumeric according to [`char::is_alphanumeric`], which also
/// accepts non-ASCII letters and digits. Whitespace, quotes, punctuation and the
/// empty string are rejected.
fn is_safe_identifier(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    name.chars()
        .all(|ch| matches!(ch, '_') || ch.is_alphanumeric())
}
122
123#[async_trait]
124impl Tool for DescribeTableTool {
125    fn name(&self) -> &'static str {
126        "sql_describe_table"
127    }
128
129    fn description(&self) -> &'static str {
130        "Describe the schema of a SQL table. Returns column names, types, and constraints."
131    }
132
133    fn parameters(&self) -> Option<Value> {
134        Some(json!({
135            "type": "object",
136            "properties": {
137                "table_name": {
138                    "type": "string",
139                    "description": "The name of the table to describe"
140                }
141            },
142            "required": ["table_name"]
143        }))
144    }
145
146    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
147        let table_name = args["table_name"]
148            .as_str()
149            .ok_or_else(|| SynapticError::Tool("missing 'table_name' parameter".to_string()))?;
150
151        if !is_safe_identifier(table_name) {
152            return Err(SynapticError::Tool(format!(
153                "Invalid table name: '{table_name}'. Only alphanumeric characters and underscores are allowed."
154            )));
155        }
156
157        let sql = format!("PRAGMA table_info({table_name})");
158        let rows = sqlx::query(&sql)
159            .fetch_all(&self.pool)
160            .await
161            .map_err(|e| SynapticError::Tool(format!("DescribeTable error: {e}")))?;
162
163        if rows.is_empty() {
164            return Err(SynapticError::Tool(format!(
165                "Table '{table_name}' does not exist."
166            )));
167        }
168
169        let columns: Vec<Value> = rows
170            .iter()
171            .map(|r| {
172                json!({
173                    "cid": r.get::<i64, _>("cid"),
174                    "name": r.get::<String, _>("name"),
175                    "type": r.get::<String, _>("type"),
176                    "not_null": r.get::<bool, _>("notnull"),
177                    "primary_key": r.get::<i64, _>("pk") > 0,
178                })
179            })
180            .collect();
181
182        Ok(json!({
183            "table": table_name,
184            "columns": columns,
185        }))
186    }
187}
188
189// ---------------------------------------------------------------------------
190// ExecuteQueryTool
191// ---------------------------------------------------------------------------
192
193/// Tool that executes a read-only SQL SELECT query.
194pub struct ExecuteQueryTool {
195    pool: SqlitePool,
196}
197
198#[async_trait]
199impl Tool for ExecuteQueryTool {
200    fn name(&self) -> &'static str {
201        "sql_execute_query"
202    }
203
204    fn description(&self) -> &'static str {
205        "Execute a read-only SQL SELECT query and return results as JSON. \
206         Only SELECT statements are allowed for safety. \
207         Returns an array of objects, one per row."
208    }
209
210    fn parameters(&self) -> Option<Value> {
211        Some(json!({
212            "type": "object",
213            "properties": {
214                "query": {
215                    "type": "string",
216                    "description": "A SQL SELECT query to execute. Must start with SELECT."
217                }
218            },
219            "required": ["query"]
220        }))
221    }
222
223    async fn call(&self, args: Value) -> Result<Value, SynapticError> {
224        let query = args["query"]
225            .as_str()
226            .ok_or_else(|| SynapticError::Tool("missing 'query' parameter".to_string()))?;
227
228        // Safety: only allow SELECT statements
229        let trimmed = query.trim_start().to_uppercase();
230        if !trimmed.starts_with("SELECT") {
231            return Err(SynapticError::Tool(
232                "Only SELECT queries are allowed for safety.".to_string(),
233            ));
234        }
235
236        let rows = sqlx::query(query)
237            .fetch_all(&self.pool)
238            .await
239            .map_err(|e| SynapticError::Tool(format!("Query execution error: {e}")))?;
240
241        let results: Vec<Value> = rows
242            .iter()
243            .map(|row| {
244                let mut obj = serde_json::Map::new();
245                for (i, col) in row.columns().iter().enumerate() {
246                    let name = col.name().to_string();
247                    let type_name = col.type_info().name();
248                    let val: Value = match type_name {
249                        "INTEGER" | "INT" | "BIGINT" => {
250                            if let Ok(v) = row.try_get::<i64, _>(i) {
251                                json!(v)
252                            } else {
253                                Value::Null
254                            }
255                        }
256                        "REAL" | "FLOAT" | "DOUBLE" => {
257                            if let Ok(v) = row.try_get::<f64, _>(i) {
258                                json!(v)
259                            } else {
260                                Value::Null
261                            }
262                        }
263                        "BOOLEAN" | "BOOL" => {
264                            if let Ok(v) = row.try_get::<bool, _>(i) {
265                                json!(v)
266                            } else {
267                                Value::Null
268                            }
269                        }
270                        _ => {
271                            // Default to string for TEXT, BLOB, NULL, etc.
272                            if let Ok(v) = row.try_get::<String, _>(i) {
273                                json!(v)
274                            } else {
275                                Value::Null
276                            }
277                        }
278                    };
279                    obj.insert(name, val);
280                }
281                Value::Object(obj)
282            })
283            .collect();
284
285        Ok(json!({
286            "rows": results,
287            "row_count": results.len(),
288        }))
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::SqlitePoolOptions;

    /// Opens a fresh in-memory SQLite pool for a single test.
    async fn test_pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite")
    }

    #[test]
    fn is_safe_identifier_allows_valid() {
        assert!(is_safe_identifier("users"));
        assert!(is_safe_identifier("my_table_123"));
        assert!(is_safe_identifier("ABC"));
    }

    #[test]
    fn is_safe_identifier_blocks_injection() {
        assert!(!is_safe_identifier("users; DROP TABLE users--"));
        assert!(!is_safe_identifier("users--"));
        assert!(!is_safe_identifier(""));
        assert!(!is_safe_identifier("tab le"));
    }

    #[tokio::test]
    async fn list_tables_empty_db() {
        let pool = test_pool().await;
        let tool = ListTablesTool { pool };
        let result = tool.call(json!({})).await.unwrap();
        assert_eq!(result["tables"], json!([]));
    }

    #[tokio::test]
    async fn list_tables_with_data() {
        let pool = test_pool().await;
        sqlx::query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
            .execute(&pool)
            .await
            .unwrap();
        let tool = ListTablesTool { pool: pool.clone() };
        let result = tool.call(json!({})).await.unwrap();
        let tables = result["tables"].as_array().unwrap();
        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0], "users");
    }

    #[tokio::test]
    async fn describe_table() {
        let pool = test_pool().await;
        sqlx::query(
            "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL)",
        )
        .execute(&pool)
        .await
        .unwrap();
        let tool = DescribeTableTool { pool: pool.clone() };
        let result = tool.call(json!({"table_name": "products"})).await.unwrap();
        let cols = result["columns"].as_array().unwrap();
        assert_eq!(cols.len(), 3);
        assert_eq!(cols[0]["name"], "id");
        assert_eq!(cols[0]["primary_key"], true);
    }

    // The identifier check must reject the name before any SQL is built.
    #[tokio::test]
    async fn describe_table_rejects_unsafe_name() {
        let pool = test_pool().await;
        let tool = DescribeTableTool { pool };
        let err = tool
            .call(json!({"table_name": "users; DROP TABLE users--"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    // A well-formed name that matches no table is reported as an error, not as an
    // empty column list.
    #[tokio::test]
    async fn describe_table_missing_table_errors() {
        let pool = test_pool().await;
        let tool = DescribeTableTool { pool };
        let err = tool
            .call(json!({"table_name": "no_such_table"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    #[tokio::test]
    async fn execute_select_query() {
        let pool = test_pool().await;
        sqlx::query("CREATE TABLE items (id INTEGER, label TEXT)")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("INSERT INTO items VALUES (1, 'alpha'), (2, 'beta')")
            .execute(&pool)
            .await
            .unwrap();
        let tool = ExecuteQueryTool { pool: pool.clone() };
        let result = tool
            .call(json!({"query": "SELECT id, label FROM items ORDER BY id"}))
            .await
            .unwrap();
        assert_eq!(result["row_count"], 2);
        assert_eq!(result["rows"][0]["label"], "alpha");
    }

    #[tokio::test]
    async fn execute_non_select_rejected() {
        let pool = test_pool().await;
        let tool = ExecuteQueryTool { pool };
        let err = tool
            .call(json!({"query": "DROP TABLE users"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    // Both argument-taking tools must fail cleanly when their required argument
    // is absent.
    #[tokio::test]
    async fn missing_parameters_are_rejected() {
        let pool = test_pool().await;

        let describe = DescribeTableTool { pool: pool.clone() };
        let err = describe.call(json!({})).await.unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));

        let execute = ExecuteQueryTool { pool };
        let err = execute.call(json!({})).await.unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }
}
384            .await
385            .unwrap_err();
386        assert!(matches!(err, SynapticError::Tool(_)));
387    }
388}