synaptic_sqltoolkit/
use async_trait::async_trait;
use serde_json::{json, Value};
use sqlx::{Column, Row, SqlitePool, TypeInfo, ValueRef};
use std::sync::Arc;
use synaptic_core::{SynapticError, Tool};
36
37pub struct SqlToolkit {
43 pool: SqlitePool,
44}
45
46impl SqlToolkit {
47 pub fn sqlite(pool: SqlitePool) -> Self {
49 Self { pool }
50 }
51
52 pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
54 vec![
55 Arc::new(ListTablesTool {
56 pool: self.pool.clone(),
57 }),
58 Arc::new(DescribeTableTool {
59 pool: self.pool.clone(),
60 }),
61 Arc::new(ExecuteQueryTool {
62 pool: self.pool.clone(),
63 }),
64 ]
65 }
66}
67
68pub struct ListTablesTool {
74 pool: SqlitePool,
75}
76
77#[async_trait]
78impl Tool for ListTablesTool {
79 fn name(&self) -> &'static str {
80 "sql_list_tables"
81 }
82
83 fn description(&self) -> &'static str {
84 "List all tables available in the SQL database. Returns a JSON array of table names."
85 }
86
87 fn parameters(&self) -> Option<Value> {
88 Some(json!({
89 "type": "object",
90 "properties": {},
91 "required": []
92 }))
93 }
94
95 async fn call(&self, _args: Value) -> Result<Value, SynapticError> {
96 let rows = sqlx::query(
97 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
98 )
99 .fetch_all(&self.pool)
100 .await
101 .map_err(|e| SynapticError::Tool(format!("ListTables error: {e}")))?;
102
103 let tables: Vec<String> = rows.iter().map(|r| r.get::<String, _>("name")).collect();
104
105 Ok(json!({ "tables": tables }))
106 }
107}
108
109pub struct DescribeTableTool {
115 pool: SqlitePool,
116}
117
/// Reports whether `name` is non-empty and made only of underscores and
/// alphanumeric characters (Unicode-aware, per `char::is_alphanumeric`).
/// `sql_describe_table` rejects any table name that fails this check.
fn is_safe_identifier(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }
    !name.chars().any(|ch| !(ch == '_' || ch.is_alphanumeric()))
}
122
123#[async_trait]
124impl Tool for DescribeTableTool {
125 fn name(&self) -> &'static str {
126 "sql_describe_table"
127 }
128
129 fn description(&self) -> &'static str {
130 "Describe the schema of a SQL table. Returns column names, types, and constraints."
131 }
132
133 fn parameters(&self) -> Option<Value> {
134 Some(json!({
135 "type": "object",
136 "properties": {
137 "table_name": {
138 "type": "string",
139 "description": "The name of the table to describe"
140 }
141 },
142 "required": ["table_name"]
143 }))
144 }
145
146 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
147 let table_name = args["table_name"]
148 .as_str()
149 .ok_or_else(|| SynapticError::Tool("missing 'table_name' parameter".to_string()))?;
150
151 if !is_safe_identifier(table_name) {
152 return Err(SynapticError::Tool(format!(
153 "Invalid table name: '{table_name}'. Only alphanumeric characters and underscores are allowed."
154 )));
155 }
156
157 let sql = format!("PRAGMA table_info({table_name})");
158 let rows = sqlx::query(&sql)
159 .fetch_all(&self.pool)
160 .await
161 .map_err(|e| SynapticError::Tool(format!("DescribeTable error: {e}")))?;
162
163 if rows.is_empty() {
164 return Err(SynapticError::Tool(format!(
165 "Table '{table_name}' does not exist."
166 )));
167 }
168
169 let columns: Vec<Value> = rows
170 .iter()
171 .map(|r| {
172 json!({
173 "cid": r.get::<i64, _>("cid"),
174 "name": r.get::<String, _>("name"),
175 "type": r.get::<String, _>("type"),
176 "not_null": r.get::<bool, _>("notnull"),
177 "primary_key": r.get::<i64, _>("pk") > 0,
178 })
179 })
180 .collect();
181
182 Ok(json!({
183 "table": table_name,
184 "columns": columns,
185 }))
186 }
187}
188
189pub struct ExecuteQueryTool {
195 pool: SqlitePool,
196}
197
198#[async_trait]
199impl Tool for ExecuteQueryTool {
200 fn name(&self) -> &'static str {
201 "sql_execute_query"
202 }
203
204 fn description(&self) -> &'static str {
205 "Execute a read-only SQL SELECT query and return results as JSON. \
206 Only SELECT statements are allowed for safety. \
207 Returns an array of objects, one per row."
208 }
209
210 fn parameters(&self) -> Option<Value> {
211 Some(json!({
212 "type": "object",
213 "properties": {
214 "query": {
215 "type": "string",
216 "description": "A SQL SELECT query to execute. Must start with SELECT."
217 }
218 },
219 "required": ["query"]
220 }))
221 }
222
223 async fn call(&self, args: Value) -> Result<Value, SynapticError> {
224 let query = args["query"]
225 .as_str()
226 .ok_or_else(|| SynapticError::Tool("missing 'query' parameter".to_string()))?;
227
228 let trimmed = query.trim_start().to_uppercase();
230 if !trimmed.starts_with("SELECT") {
231 return Err(SynapticError::Tool(
232 "Only SELECT queries are allowed for safety.".to_string(),
233 ));
234 }
235
236 let rows = sqlx::query(query)
237 .fetch_all(&self.pool)
238 .await
239 .map_err(|e| SynapticError::Tool(format!("Query execution error: {e}")))?;
240
241 let results: Vec<Value> = rows
242 .iter()
243 .map(|row| {
244 let mut obj = serde_json::Map::new();
245 for (i, col) in row.columns().iter().enumerate() {
246 let name = col.name().to_string();
247 let type_name = col.type_info().name();
248 let val: Value = match type_name {
249 "INTEGER" | "INT" | "BIGINT" => {
250 if let Ok(v) = row.try_get::<i64, _>(i) {
251 json!(v)
252 } else {
253 Value::Null
254 }
255 }
256 "REAL" | "FLOAT" | "DOUBLE" => {
257 if let Ok(v) = row.try_get::<f64, _>(i) {
258 json!(v)
259 } else {
260 Value::Null
261 }
262 }
263 "BOOLEAN" | "BOOL" => {
264 if let Ok(v) = row.try_get::<bool, _>(i) {
265 json!(v)
266 } else {
267 Value::Null
268 }
269 }
270 _ => {
271 if let Ok(v) = row.try_get::<String, _>(i) {
273 json!(v)
274 } else {
275 Value::Null
276 }
277 }
278 };
279 obj.insert(name, val);
280 }
281 Value::Object(obj)
282 })
283 .collect();
284
285 Ok(json!({
286 "rows": results,
287 "row_count": results.len(),
288 }))
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::SqlitePoolOptions;

    /// Opens a fresh in-memory database. Capping the pool at one connection
    /// guarantees every statement in a test runs against the same database.
    async fn test_pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite")
    }

    /// Runs a setup statement, panicking if it fails.
    async fn exec(pool: &SqlitePool, sql: &str) {
        sqlx::query(sql).execute(pool).await.unwrap();
    }

    #[test]
    fn is_safe_identifier_allows_valid() {
        assert!(is_safe_identifier("users"));
        assert!(is_safe_identifier("my_table_123"));
        assert!(is_safe_identifier("ABC"));
    }

    #[test]
    fn is_safe_identifier_blocks_injection() {
        assert!(!is_safe_identifier("users; DROP TABLE users--"));
        assert!(!is_safe_identifier("users--"));
        assert!(!is_safe_identifier(""));
        assert!(!is_safe_identifier("tab le"));
    }

    #[tokio::test]
    async fn list_tables_empty_db() {
        let tool = ListTablesTool { pool: test_pool().await };
        let result = tool.call(json!({})).await.unwrap();
        assert_eq!(result["tables"], json!([]));
    }

    #[tokio::test]
    async fn list_tables_with_data() {
        let pool = test_pool().await;
        exec(&pool, "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").await;
        let tool = ListTablesTool { pool: pool.clone() };
        let result = tool.call(json!({})).await.unwrap();
        let tables = result["tables"].as_array().unwrap();
        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0], "users");
    }

    #[tokio::test]
    async fn list_tables_sorted_by_name() {
        let pool = test_pool().await;
        exec(&pool, "CREATE TABLE zeta (id INTEGER)").await;
        exec(&pool, "CREATE TABLE alpha (id INTEGER)").await;
        let tool = ListTablesTool { pool: pool.clone() };
        let result = tool.call(json!({})).await.unwrap();
        assert_eq!(result["tables"], json!(["alpha", "zeta"]));
    }

    #[tokio::test]
    async fn describe_table() {
        let pool = test_pool().await;
        exec(
            &pool,
            "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL)",
        )
        .await;
        let tool = DescribeTableTool { pool: pool.clone() };
        let result = tool.call(json!({"table_name": "products"})).await.unwrap();
        let cols = result["columns"].as_array().unwrap();
        assert_eq!(cols.len(), 3);
        assert_eq!(cols[0]["name"], "id");
        assert_eq!(cols[0]["primary_key"], true);
        assert_eq!(cols[1]["not_null"], true);
        assert_eq!(cols[2]["type"], "REAL");
        assert_eq!(cols[2]["not_null"], false);
    }

    #[tokio::test]
    async fn describe_unknown_table_is_error() {
        let tool = DescribeTableTool { pool: test_pool().await };
        let err = tool
            .call(json!({"table_name": "missing"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    #[tokio::test]
    async fn describe_rejects_unsafe_name() {
        let tool = DescribeTableTool { pool: test_pool().await };
        let err = tool
            .call(json!({"table_name": "users; DROP TABLE users"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    #[tokio::test]
    async fn describe_requires_table_name() {
        let tool = DescribeTableTool { pool: test_pool().await };
        let err = tool.call(json!({})).await.unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    #[tokio::test]
    async fn execute_select_query() {
        let pool = test_pool().await;
        exec(&pool, "CREATE TABLE items (id INTEGER, label TEXT)").await;
        exec(&pool, "INSERT INTO items VALUES (1, 'alpha'), (2, 'beta')").await;
        let tool = ExecuteQueryTool { pool: pool.clone() };
        let result = tool
            .call(json!({"query": "SELECT id, label FROM items ORDER BY id"}))
            .await
            .unwrap();
        assert_eq!(result["row_count"], 2);
        assert_eq!(result["rows"][0]["label"], "alpha");
    }

    #[tokio::test]
    async fn execute_non_select_rejected() {
        let tool = ExecuteQueryTool { pool: test_pool().await };
        let err = tool
            .call(json!({"query": "DROP TABLE users"}))
            .await
            .unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }

    #[tokio::test]
    async fn execute_requires_query() {
        let tool = ExecuteQueryTool { pool: test_pool().await };
        let err = tool.call(json!({})).await.unwrap_err();
        assert!(matches!(err, SynapticError::Tool(_)));
    }
}