Skip to main content

sage_runtime/tools/
database.rs

//! RFC-0011: Database tool for Sage agents.
//!
//! Provides the `Database` tool ([`DatabaseClient`]) for executing SQL queries and statements.
//! Requires the `database` feature to be enabled.
5
6use crate::error::{SageError, SageResult};
7use crate::mock::{try_get_mock, MockResponse};
8
9#[cfg(feature = "database")]
10use sqlx::{any::AnyRow, AnyPool, Column, Row};
11
12/// A row returned from a database query.
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct DbRow {
15    /// Column names.
16    pub columns: Vec<String>,
17    /// Values as strings.
18    pub values: Vec<String>,
19}
20
/// Handle through which Sage agents run SQL against a database.
///
/// With the `database` feature enabled, this wraps a `sqlx` [`AnyPool`], so the backend is
/// selected by the connection URL. Without the feature it holds no state, and its
/// constructors return an error.
///
/// Requires the `database` feature to be enabled.
#[derive(Debug, Clone)]
pub struct DatabaseClient {
    /// Zero-sized placeholder used when database support is compiled out.
    #[cfg(not(feature = "database"))]
    _marker: std::marker::PhantomData<()>,
    /// Connection pool used for every query and statement.
    #[cfg(feature = "database")]
    pool: AnyPool,
}
31
32impl DatabaseClient {
33    /// Create a new database client by connecting to the given URL.
34    ///
35    /// # Arguments
36    /// * `url` - Database connection URL (e.g., "postgres://localhost/db" or "sqlite::memory:")
37    #[cfg(feature = "database")]
38    pub async fn connect(url: &str) -> SageResult<Self> {
39        // Install default drivers
40        sqlx::any::install_default_drivers();
41
42        let pool = AnyPool::connect(url)
43            .await
44            .map_err(|e| SageError::Tool(format!("Database connection failed: {e}")))?;
45        Ok(Self { pool })
46    }
47
48    /// Create a new database client by connecting to the given URL.
49    #[cfg(not(feature = "database"))]
50    pub async fn connect(_url: &str) -> SageResult<Self> {
51        Err(SageError::Tool(
52            "Database support not enabled. Compile with the 'database' feature.".to_string(),
53        ))
54    }
55
56    /// Create a new database client from environment variables.
57    ///
58    /// Reads:
59    /// - `SAGE_DATABASE_URL`: Database connection URL (required)
60    #[cfg(feature = "database")]
61    pub async fn from_env() -> SageResult<Self> {
62        let url = std::env::var("SAGE_DATABASE_URL").map_err(|_| {
63            SageError::Tool("SAGE_DATABASE_URL environment variable not set".to_string())
64        })?;
65        Self::connect(&url).await
66    }
67
68    /// Create a new database client from environment variables.
69    #[cfg(not(feature = "database"))]
70    pub async fn from_env() -> SageResult<Self> {
71        Err(SageError::Tool(
72            "Database support not enabled. Compile with the 'database' feature.".to_string(),
73        ))
74    }
75
76    /// Execute a SQL query and return the results.
77    ///
78    /// # Arguments
79    /// * `sql` - The SQL query to execute
80    ///
81    /// # Returns
82    /// A list of rows, each containing column names and values.
83    #[cfg(feature = "database")]
84    pub async fn query(&self, sql: String) -> SageResult<Vec<DbRow>> {
85        // Check for mock response first
86        if let Some(mock_response) = try_get_mock("Database", "query") {
87            return Self::apply_mock_vec(mock_response);
88        }
89
90        let rows: Vec<AnyRow> = sqlx::query(&sql)
91            .fetch_all(&self.pool)
92            .await
93            .map_err(|e| SageError::Tool(format!("Query failed: {e}")))?;
94
95        let result: Vec<DbRow> = rows
96            .iter()
97            .map(|row| {
98                let columns: Vec<String> =
99                    row.columns().iter().map(|c| c.name().to_string()).collect();
100                let values: Vec<String> = (0..row.columns().len())
101                    .map(|i| {
102                        // Try to get the value as different types
103                        if let Ok(v) = row.try_get::<String, _>(i) {
104                            v
105                        } else if let Ok(v) = row.try_get::<i64, _>(i) {
106                            v.to_string()
107                        } else if let Ok(v) = row.try_get::<i32, _>(i) {
108                            v.to_string()
109                        } else if let Ok(v) = row.try_get::<f64, _>(i) {
110                            v.to_string()
111                        } else if let Ok(v) = row.try_get::<bool, _>(i) {
112                            v.to_string()
113                        } else {
114                            // Fallback: try to get raw value as Option<String>
115                            row.try_get::<Option<String>, _>(i)
116                                .ok()
117                                .flatten()
118                                .unwrap_or_else(|| "null".to_string())
119                        }
120                    })
121                    .collect();
122                DbRow { columns, values }
123            })
124            .collect();
125
126        Ok(result)
127    }
128
129    /// Execute a SQL query and return the results.
130    #[cfg(not(feature = "database"))]
131    pub async fn query(&self, _sql: String) -> SageResult<Vec<DbRow>> {
132        // Check for mock response first (allows testing without database feature)
133        if let Some(mock_response) = try_get_mock("Database", "query") {
134            return Self::apply_mock_vec(mock_response);
135        }
136
137        Err(SageError::Tool(
138            "Database support not enabled. Compile with the 'database' feature.".to_string(),
139        ))
140    }
141
142    /// Execute a SQL statement (INSERT, UPDATE, DELETE) and return affected row count.
143    ///
144    /// # Arguments
145    /// * `sql` - The SQL statement to execute
146    ///
147    /// # Returns
148    /// Number of rows affected.
149    #[cfg(feature = "database")]
150    pub async fn execute(&self, sql: String) -> SageResult<i64> {
151        // Check for mock response first
152        if let Some(mock_response) = try_get_mock("Database", "execute") {
153            return Self::apply_mock_i64(mock_response);
154        }
155
156        let result = sqlx::query(&sql)
157            .execute(&self.pool)
158            .await
159            .map_err(|e| SageError::Tool(format!("Execute failed: {e}")))?;
160
161        Ok(result.rows_affected() as i64)
162    }
163
164    /// Execute a SQL statement and return affected row count.
165    #[cfg(not(feature = "database"))]
166    pub async fn execute(&self, _sql: String) -> SageResult<i64> {
167        // Check for mock response first (allows testing without database feature)
168        if let Some(mock_response) = try_get_mock("Database", "execute") {
169            return Self::apply_mock_i64(mock_response);
170        }
171
172        Err(SageError::Tool(
173            "Database support not enabled. Compile with the 'database' feature.".to_string(),
174        ))
175    }
176
177    /// Apply a mock response for Vec<DbRow>.
178    fn apply_mock_vec(mock_response: MockResponse) -> SageResult<Vec<DbRow>> {
179        match mock_response {
180            MockResponse::Value(v) => serde_json::from_value(v)
181                .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
182            MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
183        }
184    }
185
186    /// Apply a mock response for i64.
187    fn apply_mock_i64(mock_response: MockResponse) -> SageResult<i64> {
188        match mock_response {
189            MockResponse::Value(v) => serde_json::from_value(v)
190                .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
191            MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
192        }
193    }
194}
195
#[cfg(all(test, feature = "database"))]
mod tests {
    use super::*;

    /// Shared-cache in-memory SQLite URL, so every pooled connection sees the same database.
    const SHARED_MEMORY_URL: &str = "sqlite:file::memory:?mode=memory&cache=shared";

    /// Build a URL for a fresh file-backed SQLite database inside `dir`.
    ///
    /// A file-backed database avoids the per-connection isolation of plain in-memory
    /// databases when the pool hands out different connections.
    fn temp_db_url(dir: &tempfile::TempDir, name: &str) -> String {
        let db_path = dir.path().join(name);
        // Create the file up front so the connection does not depend on create-on-open.
        std::fs::write(&db_path, "").unwrap();
        format!("sqlite:{}?mode=rwc", db_path.display())
    }

    #[tokio::test]
    async fn database_connect_sqlite() {
        let client = DatabaseClient::connect(SHARED_MEMORY_URL).await.unwrap();
        drop(client);
    }

    #[tokio::test]
    async fn database_execute_and_query() {
        let temp_dir = tempfile::tempdir().unwrap();
        let url = temp_db_url(&temp_dir, "test.db");

        let client = DatabaseClient::connect(&url).await.unwrap();

        client
            .execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)".to_string())
            .await
            .unwrap();

        let affected = client
            .execute("INSERT INTO test (id, name) VALUES (1, 'Alice'), (2, 'Bob')".to_string())
            .await
            .unwrap();
        assert_eq!(affected, 2);

        let rows = client
            .query("SELECT id, name FROM test ORDER BY id".to_string())
            .await
            .unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].columns, vec!["id", "name"]);
        assert_eq!(rows[0].values, vec!["1", "Alice"]);
        assert_eq!(rows[1].values, vec!["2", "Bob"]);
    }

    #[tokio::test]
    async fn database_query_null_value() {
        let temp_dir = tempfile::tempdir().unwrap();
        let url = temp_db_url(&temp_dir, "nulls.db");

        let client = DatabaseClient::connect(&url).await.unwrap();

        client
            .execute("CREATE TABLE notes (id INTEGER PRIMARY KEY, note TEXT)".to_string())
            .await
            .unwrap();
        client
            .execute("INSERT INTO notes (id, note) VALUES (1, NULL)".to_string())
            .await
            .unwrap();

        let rows = client
            .query("SELECT id, note FROM notes".to_string())
            .await
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].columns, vec!["id", "note"]);
        // SQL NULL is reported as the literal string "null".
        assert_eq!(rows[0].values, vec!["1", "null"]);
    }

    #[tokio::test]
    async fn database_query_select_one() {
        let client = DatabaseClient::connect(SHARED_MEMORY_URL).await.unwrap();
        let rows = client.query("SELECT 1 as value".to_string()).await.unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].columns, vec!["value"]);
        assert_eq!(rows[0].values, vec!["1"]);
    }
}