sage_runtime/tools/
database.rs1use crate::error::{SageError, SageResult};
7use crate::mock::{try_get_mock, MockResponse};
8
9#[cfg(feature = "database")]
10use sqlx::{any::AnyRow, AnyPool, Column, Row};
11
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
14pub struct DbRow {
15 pub columns: Vec<String>,
17 pub values: Vec<String>,
19}
20
/// Handle to a SQL database reached through sqlx's driver-agnostic `Any` backend.
///
/// With the `database` feature enabled, the client wraps a connection pool.
/// Without it, the struct is an empty placeholder so code that names the type
/// still compiles. Its constructors then return an error explaining that the
/// feature is missing.
#[derive(Debug, Clone)]
pub struct DatabaseClient {
    /// Pool the queries run against.
    #[cfg(feature = "database")]
    pool: AnyPool,
    /// Zero-sized stand-in that keeps the struct non-constructible from outside
    /// when there is no pool field.
    #[cfg(not(feature = "database"))]
    _marker: std::marker::PhantomData<()>,
}
31
32impl DatabaseClient {
33 #[cfg(feature = "database")]
38 pub async fn connect(url: &str) -> SageResult<Self> {
39 sqlx::any::install_default_drivers();
41
42 let pool = AnyPool::connect(url)
43 .await
44 .map_err(|e| SageError::Tool(format!("Database connection failed: {e}")))?;
45 Ok(Self { pool })
46 }
47
48 #[cfg(not(feature = "database"))]
50 pub async fn connect(_url: &str) -> SageResult<Self> {
51 Err(SageError::Tool(
52 "Database support not enabled. Compile with the 'database' feature.".to_string(),
53 ))
54 }
55
56 #[cfg(feature = "database")]
61 pub async fn from_env() -> SageResult<Self> {
62 let url = std::env::var("SAGE_DATABASE_URL").map_err(|_| {
63 SageError::Tool("SAGE_DATABASE_URL environment variable not set".to_string())
64 })?;
65 Self::connect(&url).await
66 }
67
68 #[cfg(not(feature = "database"))]
70 pub async fn from_env() -> SageResult<Self> {
71 Err(SageError::Tool(
72 "Database support not enabled. Compile with the 'database' feature.".to_string(),
73 ))
74 }
75
76 #[cfg(feature = "database")]
84 pub async fn query(&self, sql: String) -> SageResult<Vec<DbRow>> {
85 if let Some(mock_response) = try_get_mock("Database", "query") {
87 return Self::apply_mock_vec(mock_response);
88 }
89
90 let rows: Vec<AnyRow> = sqlx::query(&sql)
91 .fetch_all(&self.pool)
92 .await
93 .map_err(|e| SageError::Tool(format!("Query failed: {e}")))?;
94
95 let result: Vec<DbRow> = rows
96 .iter()
97 .map(|row| {
98 let columns: Vec<String> =
99 row.columns().iter().map(|c| c.name().to_string()).collect();
100 let values: Vec<String> = (0..row.columns().len())
101 .map(|i| {
102 if let Ok(v) = row.try_get::<String, _>(i) {
104 v
105 } else if let Ok(v) = row.try_get::<i64, _>(i) {
106 v.to_string()
107 } else if let Ok(v) = row.try_get::<i32, _>(i) {
108 v.to_string()
109 } else if let Ok(v) = row.try_get::<f64, _>(i) {
110 v.to_string()
111 } else if let Ok(v) = row.try_get::<bool, _>(i) {
112 v.to_string()
113 } else {
114 row.try_get::<Option<String>, _>(i)
116 .ok()
117 .flatten()
118 .unwrap_or_else(|| "null".to_string())
119 }
120 })
121 .collect();
122 DbRow { columns, values }
123 })
124 .collect();
125
126 Ok(result)
127 }
128
129 #[cfg(not(feature = "database"))]
131 pub async fn query(&self, _sql: String) -> SageResult<Vec<DbRow>> {
132 if let Some(mock_response) = try_get_mock("Database", "query") {
134 return Self::apply_mock_vec(mock_response);
135 }
136
137 Err(SageError::Tool(
138 "Database support not enabled. Compile with the 'database' feature.".to_string(),
139 ))
140 }
141
142 #[cfg(feature = "database")]
150 pub async fn execute(&self, sql: String) -> SageResult<i64> {
151 if let Some(mock_response) = try_get_mock("Database", "execute") {
153 return Self::apply_mock_i64(mock_response);
154 }
155
156 let result = sqlx::query(&sql)
157 .execute(&self.pool)
158 .await
159 .map_err(|e| SageError::Tool(format!("Execute failed: {e}")))?;
160
161 Ok(result.rows_affected() as i64)
162 }
163
164 #[cfg(not(feature = "database"))]
166 pub async fn execute(&self, _sql: String) -> SageResult<i64> {
167 if let Some(mock_response) = try_get_mock("Database", "execute") {
169 return Self::apply_mock_i64(mock_response);
170 }
171
172 Err(SageError::Tool(
173 "Database support not enabled. Compile with the 'database' feature.".to_string(),
174 ))
175 }
176
177 fn apply_mock_vec(mock_response: MockResponse) -> SageResult<Vec<DbRow>> {
179 match mock_response {
180 MockResponse::Value(v) => serde_json::from_value(v)
181 .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
182 MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
183 }
184 }
185
186 fn apply_mock_i64(mock_response: MockResponse) -> SageResult<i64> {
188 match mock_response {
189 MockResponse::Value(v) => serde_json::from_value(v)
190 .map_err(|e| SageError::Tool(format!("mock deserialize: {e}"))),
191 MockResponse::Fail(msg) => Err(SageError::Tool(msg)),
192 }
193 }
194}
195
#[cfg(all(test, feature = "database"))]
mod tests {
    use super::*;

    /// Shared-cache in-memory SQLite database; nothing is written to disk.
    const IN_MEMORY_URL: &str = "sqlite:file::memory:?mode=memory&cache=shared";

    /// Connecting to an in-memory SQLite database succeeds.
    #[tokio::test]
    async fn database_connect_sqlite() {
        let _ = DatabaseClient::connect(IN_MEMORY_URL).await.unwrap();
    }

    /// Round-trips a table through CREATE, a multi-row INSERT and an ordered SELECT
    /// against a file-backed database in a temporary directory.
    #[tokio::test]
    async fn database_execute_and_query() {
        // `dir` must outlive the client: dropping it deletes the directory.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db");
        std::fs::write(&path, "").unwrap();

        let client = DatabaseClient::connect(&format!("sqlite:{}?mode=rwc", path.display()))
            .await
            .unwrap();

        client
            .execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)".to_string())
            .await
            .unwrap();

        let inserted = client
            .execute("INSERT INTO test (id, name) VALUES (1, 'Alice'), (2, 'Bob')".to_string())
            .await
            .unwrap();
        assert_eq!(inserted, 2);

        let rows = client
            .query("SELECT id, name FROM test ORDER BY id".to_string())
            .await
            .unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].columns, vec!["id", "name"]);

        let cells: Vec<Vec<String>> = rows.iter().map(|r| r.values.clone()).collect();
        assert_eq!(cells, vec![vec!["1", "Alice"], vec!["2", "Bob"]]);
    }

    /// A constant SELECT yields one row whose column uses the alias name.
    #[tokio::test]
    async fn database_query_select_one() {
        let client = DatabaseClient::connect(IN_MEMORY_URL).await.unwrap();
        let rows = client.query("SELECT 1 as value".to_string()).await.unwrap();

        assert_eq!(rows.len(), 1);
        let only = &rows[0];
        assert_eq!(only.columns, vec!["value"]);
        assert_eq!(only.values, vec!["1"]);
    }
}