Skip to main content

sage_runtime/tools/
database.rs

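//! RFC-0011: Database tool for Sage agents.
//!
//! Provides the `Database` tool, which lets agents run SQL queries and statements.
//! Requires the `database` feature to be enabled; without it, connecting always fails and
//! `query`/`execute` only succeed when a mock response is registered.
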
6use crate::error::{SageError, SageResult};
7use crate::mock::{try_get_mock, MockResponse};
8
9#[cfg(feature = "database")]
10use sqlx::{any::AnyRow, AnyPool, Column, Row};
11
12/// A row returned from a database query.
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct DbRow {
15    /// Column names.
16    pub columns: Vec<String>,
17    /// Values as strings.
18    pub values: Vec<String>,
19}
20
/// Handle to a SQL database used by Sage agents.
///
/// With the `database` feature enabled this wraps a sqlx connection pool. Without it the
/// type is an empty placeholder whose constructors always return an error.
#[derive(Clone, Debug)]
pub struct DatabaseClient {
    /// Placeholder that keeps the struct non-constructible from outside this module when
    /// database support is compiled out.
    #[cfg(not(feature = "database"))]
    _marker: std::marker::PhantomData<()>,
    /// Connection pool for whichever backend the connection URL selected.
    #[cfg(feature = "database")]
    pool: AnyPool,
}
31
32impl DatabaseClient {
33    /// Create a new database client by connecting to the given URL.
34    ///
35    /// # Arguments
36    /// * `url` - Database connection URL (e.g., "postgres://localhost/db" or "sqlite::memory:")
37    #[cfg(feature = "database")]
38    pub async fn connect(url: &str) -> SageResult<Self> {
39        // Install default drivers
40        sqlx::any::install_default_drivers();
41
42        let pool = AnyPool::connect(url)
43            .await
44            .map_err(|e| SageError::Tool(format!("Database connection failed: {e}")))?;
45        Ok(Self { pool })
46    }
47
48    /// Create a new database client by connecting to the given URL.
49    #[cfg(not(feature = "database"))]
50    pub async fn connect(_url: &str) -> SageResult<Self> {
51        Err(SageError::Tool(
52            "Database support not enabled. Compile with the 'database' feature.".to_string(),
53        ))
54    }
55
56    /// Create a new database client from environment variables.
57    ///
58    /// Reads:
59    /// - `SAGE_DATABASE_URL`: Database connection URL (required)
60    #[cfg(feature = "database")]
61    pub async fn from_env() -> SageResult<Self> {
62        let url = std::env::var("SAGE_DATABASE_URL")
63            .map_err(|_| SageError::Tool("SAGE_DATABASE_URL environment variable not set".to_string()))?;
64        Self::connect(&url).await
65    }
66
67    /// Create a new database client from environment variables.
68    #[cfg(not(feature = "database"))]
69    pub async fn from_env() -> SageResult<Self> {
70        Err(SageError::Tool(
71            "Database support not enabled. Compile with the 'database' feature.".to_string(),
72        ))
73    }
74
75    /// Execute a SQL query and return the results.
76    ///
77    /// # Arguments
78    /// * `sql` - The SQL query to execute
79    ///
80    /// # Returns
81    /// A list of rows, each containing column names and values.
82    #[cfg(feature = "database")]
83    pub async fn query(&self, sql: String) -> SageResult<Vec<DbRow>> {
84        // Check for mock response first
85        if let Some(mock_response) = try_get_mock("Database", "query") {
86            return Self::apply_mock_vec(mock_response);
87        }
88
89        let rows: Vec<AnyRow> = sqlx::query(&sql)
90            .fetch_all(&self.pool)
91            .await
92            .map_err(|e| SageError::Tool(format!("Query failed: {e}")))?;
93
94        let result: Vec<DbRow> = rows
95            .iter()
96            .map(|row| {
97                let columns: Vec<String> = row.columns().iter().map(|c| c.name().to_string()).collect();
98                let values: Vec<String> = (0..row.columns().len())
99                    .map(|i| {
100                        // Try to get the value as different types
101                        if let Ok(v) = row.try_get::<String, _>(i) {
102                            v
103                        } else if let Ok(v) = row.try_get::<i64, _>(i) {
104                            v.to_string()
105                        } else if let Ok(v) = row.try_get::<i32, _>(i) {
106                            v.to_string()
107                        } else if let Ok(v) = row.try_get::<f64, _>(i) {
108                            v.to_string()
109                        } else if let Ok(v) = row.try_get::<bool, _>(i) {
110                            v.to_string()
111                        } else {
112                            // Fallback: try to get raw value as Option<String>
113                            row.try_get::<Option<String>, _>(i)
114                                .ok()
115                                .flatten()
116                                .unwrap_or_else(|| "null".to_string())
117                        }
118                    })
119                    .collect();
120                DbRow { columns, values }
121            })
122            .collect();
123
124        Ok(result)
125    }
126
127    /// Execute a SQL query and return the results.
128    #[cfg(not(feature = "database"))]
129    pub async fn query(&self, _sql: String) -> SageResult<Vec<DbRow>> {
130        // Check for mock response first (allows testing without database feature)
131        if let Some(mock_response) = try_get_mock("Database", "query") {
132            return Self::apply_mock_vec(mock_response);
133        }
134
135        Err(SageError::Tool(
136            "Database support not enabled. Compile with the 'database' feature.".to_string(),
137        ))
138    }
139
140    /// Execute a SQL statement (INSERT, UPDATE, DELETE) and return affected row count.
141    ///
142    /// # Arguments
143    /// * `sql` - The SQL statement to execute
144    ///
145    /// # Returns
146    /// Number of rows affected.
147    #[cfg(feature = "database")]
148    pub async fn execute(&self, sql: String) -> SageResult<i64> {
149        // Check for mock response first
150        if let Some(mock_response) = try_get_mock("Database", "execute") {
151            return Self::apply_mock_i64(mock_response);
152        }
153
154        let result = sqlx::query(&sql)
155            .execute(&self.pool)
156            .await
157            .map_err(|e| SageError::Tool(format!("Execute failed: {e}")))?;
158
159        Ok(result.rows_affected() as i64)
160    }
161
162    /// Execute a SQL statement and return affected row count.
163    #[cfg(not(feature = "database"))]
164    pub async fn execute(&self, _sql: String) -> SageResult<i64> {
165        // Check for mock response first (allows testing without database feature)
166        if let Some(mock_response) = try_get_mock("Database", "execute") {
167            return Self::apply_mock_i64(mock_response);
168        }
169
170        Err(SageError::Tool(
171            "Database support not enabled. Compile with the 'database' feature.".to_string(),
172        ))
173    }
174
175    /// Apply a mock response for Vec<DbRow>.
176    fn apply_mock_vec(mock_response: MockResponse) -> SageResult<Vec<DbRow>> {
177        match mock_response {
178            MockResponse::Value(v) => serde_json::from_value(v)
179                .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
180            MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
181        }
182    }
183
184    /// Apply a mock response for i64.
185    fn apply_mock_i64(mock_response: MockResponse) -> SageResult<i64> {
186        match mock_response {
187            MockResponse::Value(v) => serde_json::from_value(v)
188                .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
189            MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
190        }
191    }
192}
193
#[cfg(all(test, feature = "database"))]
mod tests {
    use super::*;

    /// Shared-cache in-memory SQLite URL for tests that need no persistent schema.
    const SHARED_MEMORY_URL: &str = "sqlite:file::memory:?mode=memory&cache=shared";

    #[tokio::test]
    async fn database_connect_sqlite() {
        // Connecting must succeed; the client is released immediately afterwards.
        drop(DatabaseClient::connect(SHARED_MEMORY_URL).await.unwrap());
    }

    #[tokio::test]
    async fn database_execute_and_query() {
        // Back the database with a temporary file rather than plain in-memory storage, to
        // avoid pool issues (separate pooled connections not sharing one in-memory DB).
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db");
        std::fs::write(&path, "").unwrap(); // pre-create the empty database file
        let client = DatabaseClient::connect(&format!("sqlite:{}?mode=rwc", path.display()))
            .await
            .unwrap();

        client
            .execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)".to_string())
            .await
            .unwrap();

        let inserted = client
            .execute("INSERT INTO test (id, name) VALUES (1, 'Alice'), (2, 'Bob')".to_string())
            .await
            .unwrap();
        assert_eq!(inserted, 2);

        let rows = client
            .query("SELECT id, name FROM test ORDER BY id".to_string())
            .await
            .unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].columns, ["id", "name"]);
        assert_eq!(rows[0].values, ["1", "Alice"]);
        assert_eq!(rows[1].values, ["2", "Bob"]);
    }

    #[tokio::test]
    async fn database_query_select_one() {
        let client = DatabaseClient::connect(SHARED_MEMORY_URL).await.unwrap();
        let rows = client.query("SELECT 1 as value".to_string()).await.unwrap();
        assert_eq!(rows.len(), 1);
        let only = &rows[0];
        assert_eq!(only.columns, ["value"]);
        assert_eq!(only.values, ["1"]);
    }
}
248}