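// torch_web/database.rs

//! Database integration for torch_web: a PostgreSQL connection pool, a simple
//! SQL query builder, request middleware, a placeholder migration runner and a
//! health-check handler.
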
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::{Request, Response, middleware::Middleware};
7
8#[cfg(feature = "database")]
9use {
10    sqlx::{Pool, Postgres, Row, Column},
11    serde_json::Value,
12};
13
14/// Database connection pool manager
15pub struct DatabasePool {
16    #[cfg(feature = "database")]
17    pool: Pool<Postgres>,
18    #[cfg(not(feature = "database"))]
19    _phantom: std::marker::PhantomData<()>,
20}
21
22impl DatabasePool {
23    /// Create a new database pool
24    #[cfg(feature = "database")]
25    pub async fn new(database_url: &str) -> Result<Self, sqlx::Error> {
26        let pool = sqlx::postgres::PgPoolOptions::new()
27            .max_connections(20)
28            .connect(database_url)
29            .await?;
30        
31        Ok(Self { pool })
32    }
33
34    #[cfg(not(feature = "database"))]
35    pub async fn new(_database_url: &str) -> Result<Self, Box<dyn std::error::Error>> {
36        Ok(Self {
37            _phantom: std::marker::PhantomData,
38        })
39    }
40
41    /// Execute a query and return results as JSON
42    #[cfg(feature = "database")]
43    pub async fn query_json(&self, query: &str, params: &[&str]) -> Result<Vec<Value>, sqlx::Error> {
44        let mut query_builder = sqlx::query(query);
45        
46        for param in params {
47            query_builder = query_builder.bind(param);
48        }
49        
50        let rows = query_builder.fetch_all(&self.pool).await?;
51        let mut results = Vec::new();
52        
53        for row in rows {
54            let mut json_row = serde_json::Map::new();
55            
56            for (i, column) in row.columns().iter().enumerate() {
57                let column_name = column.name();
58                let value: Option<String> = row.try_get(i).ok();
59                json_row.insert(
60                    column_name.to_string(),
61                    value.map(Value::String).unwrap_or(Value::Null),
62                );
63            }
64            
65            results.push(Value::Object(json_row));
66        }
67        
68        Ok(results)
69    }
70
71    #[cfg(not(feature = "database"))]
72    pub async fn query_json(&self, _query: &str, _params: &[&str]) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
73        Err("Database feature not enabled".into())
74    }
75
76    /// Execute a query and return the number of affected rows
77    #[cfg(feature = "database")]
78    pub async fn execute(&self, query: &str, params: &[&str]) -> Result<u64, sqlx::Error> {
79        let mut query_builder = sqlx::query(query);
80        
81        for param in params {
82            query_builder = query_builder.bind(param);
83        }
84        
85        let result = query_builder.execute(&self.pool).await?;
86        Ok(result.rows_affected())
87    }
88
89    #[cfg(not(feature = "database"))]
90    pub async fn execute(&self, _query: &str, _params: &[&str]) -> Result<u64, Box<dyn std::error::Error>> {
91        Err("Database feature not enabled".into())
92    }
93}
94
/// Simple string-based SQL builder for common single-table operations.
///
/// Values passed to `where_eq`, `where_like`, `build_insert` and
/// `build_update` are emitted as single-quoted SQL string literals with every
/// embedded `'` doubled, so a value cannot terminate its literal early. This
/// assumes the server treats backslashes literally inside `'...'` (PostgreSQL's
/// default `standard_conforming_strings = on`) — TODO confirm for the deployed
/// server.
///
/// Identifiers (the table name, column names, selected fields) and the
/// `order_by` direction are interpolated verbatim. They MUST NOT come from
/// untrusted input. For user-supplied values, prefer bound parameters
/// (`DatabasePool::query_json` / `DatabasePool::execute`).
#[derive(Debug, Clone)]
pub struct QueryBuilder {
    table: String,
    select_fields: Vec<String>,
    where_conditions: Vec<String>,
    order_by: Vec<String>,
    limit_value: Option<u64>,
    offset_value: Option<u64>,
}

impl QueryBuilder {
    /// Start a query against `table`, selecting `*` with no conditions.
    pub fn new(table: &str) -> Self {
        Self {
            table: table.to_string(),
            select_fields: vec!["*".to_string()],
            where_conditions: Vec::new(),
            order_by: Vec::new(),
            limit_value: None,
            offset_value: None,
        }
    }

    /// Replace the selected column list with `fields` (identifiers, not escaped).
    pub fn select(mut self, fields: &[&str]) -> Self {
        self.select_fields = fields.iter().map(|field| (*field).to_owned()).collect();
        self
    }

    /// Add `field = 'value'`. The value is quoted and escaped; `field` is not.
    pub fn where_eq(mut self, field: &str, value: &str) -> Self {
        self.where_conditions
            .push(format!("{} = {}", field, Self::quote_literal(value)));
        self
    }

    /// Add `field LIKE 'pattern'`. The pattern is quoted and escaped, but its
    /// `%`/`_` wildcards keep their LIKE meaning; `field` is not escaped.
    pub fn where_like(mut self, field: &str, pattern: &str) -> Self {
        self.where_conditions
            .push(format!("{} LIKE {}", field, Self::quote_literal(pattern)));
        self
    }

    /// Add an `ORDER BY field direction` term. Both parts are interpolated verbatim.
    pub fn order_by(mut self, field: &str, direction: &str) -> Self {
        self.order_by.push(format!("{} {}", field, direction));
        self
    }

    /// Set the `LIMIT` clause.
    pub fn limit(mut self, limit: u64) -> Self {
        self.limit_value = Some(limit);
        self
    }

    /// Set the `OFFSET` clause.
    pub fn offset(mut self, offset: u64) -> Self {
        self.offset_value = Some(offset);
        self
    }

    /// Render a `SELECT` statement. Clauses are emitted in this order:
    /// `WHERE` (conditions joined by `AND`), `ORDER BY`, `LIMIT`, `OFFSET`.
    pub fn build_select(&self) -> String {
        let mut query = format!("SELECT {} FROM {}", self.select_fields.join(", "), self.table);

        self.push_where(&mut query);

        if !self.order_by.is_empty() {
            query.push_str(" ORDER BY ");
            query.push_str(&self.order_by.join(", "));
        }

        if let Some(limit) = self.limit_value {
            query.push_str(&format!(" LIMIT {}", limit));
        }

        if let Some(offset) = self.offset_value {
            query.push_str(&format!(" OFFSET {}", offset));
        }

        query
    }

    /// Render an `INSERT` of one row built from `data` (column -> value).
    ///
    /// An empty map renders `INSERT INTO table DEFAULT VALUES`, because
    /// `() VALUES ()` is not valid PostgreSQL. Column order follows the map's
    /// iteration order, and each value always lines up with its own column.
    pub fn build_insert(&self, data: &HashMap<String, String>) -> String {
        if data.is_empty() {
            return format!("INSERT INTO {} DEFAULT VALUES", self.table);
        }

        // Collect both sides in a single pass so columns and values stay paired.
        let (fields, values): (Vec<&str>, Vec<String>) = data
            .iter()
            .map(|(field, value)| (field.as_str(), Self::quote_literal(value)))
            .unzip();

        format!(
            "INSERT INTO {} ({}) VALUES ({})",
            self.table,
            fields.join(", "),
            values.join(", ")
        )
    }

    /// Render an `UPDATE` setting each `column = 'value'` from `data`, filtered
    /// by any `where_*` conditions.
    ///
    /// `data` should be non-empty: an empty map renders `UPDATE table SET `,
    /// which is not valid SQL.
    pub fn build_update(&self, data: &HashMap<String, String>) -> String {
        let updates: Vec<String> = data
            .iter()
            .map(|(field, value)| format!("{} = {}", field, Self::quote_literal(value)))
            .collect();

        let mut query = format!("UPDATE {} SET {}", self.table, updates.join(", "));
        self.push_where(&mut query);
        query
    }

    /// Render a `DELETE`, filtered by any `where_*` conditions.
    ///
    /// With no conditions this deletes every row in the table.
    pub fn build_delete(&self) -> String {
        let mut query = format!("DELETE FROM {}", self.table);
        self.push_where(&mut query);
        query
    }

    /// Append ` WHERE c1 AND c2 ...` to `query` when any conditions were added.
    fn push_where(&self, query: &mut String) {
        if !self.where_conditions.is_empty() {
            query.push_str(" WHERE ");
            query.push_str(&self.where_conditions.join(" AND "));
        }
    }

    /// Render `value` as a single-quoted SQL string literal, doubling embedded quotes.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }
}
206
207/// Database middleware for automatic connection injection
208pub struct DatabaseMiddleware {
209    pool: Arc<DatabasePool>,
210}
211
212impl DatabaseMiddleware {
213    pub fn new(pool: Arc<DatabasePool>) -> Self {
214        Self { pool }
215    }
216}
217
218impl Middleware for DatabaseMiddleware {
219    fn call(
220        &self,
221        mut req: Request,
222        next: Box<dyn Fn(Request) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> + Send + Sync>,
223    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + 'static>> {
224        let pool = self.pool.clone();
225        Box::pin(async move {
226            // Inject the database pool into the request extensions
227            req.insert_extension(pool);
228            next(req).await
229        })
230    }
231}
232
233/// Extension trait to add database access to Request
234pub trait RequestDatabaseExt {
235    /// Get the database pool from the request context
236    #[cfg(feature = "database")]
237    fn db_pool(&self) -> Option<Arc<DatabasePool>>;
238
239    #[cfg(not(feature = "database"))]
240    fn db_pool(&self) -> Option<()>;
241}
242
243impl RequestDatabaseExt for crate::Request {
244    #[cfg(feature = "database")]
245    fn db_pool(&self) -> Option<Arc<DatabasePool>> {
246        self.get_extension::<Arc<DatabasePool>>().cloned()
247    }
248
249    #[cfg(not(feature = "database"))]
250    fn db_pool(&self) -> Option<()> {
251        None
252    }
253}
254
255/// Migration runner for database schema management
256pub struct MigrationRunner {
257    #[cfg(feature = "database")]
258    #[allow(dead_code)]
259    pool: Arc<DatabasePool>,
260    #[allow(dead_code)]
261    migrations_dir: String,
262    #[cfg(not(feature = "database"))]
263    _phantom: std::marker::PhantomData<()>,
264}
265
266impl MigrationRunner {
267    pub fn new(_pool: Arc<DatabasePool>, migrations_dir: &str) -> Self {
268        Self {
269            #[cfg(feature = "database")]
270            pool: _pool,
271            migrations_dir: migrations_dir.to_string(),
272            #[cfg(not(feature = "database"))]
273            _phantom: std::marker::PhantomData,
274        }
275    }
276
277    /// Run pending migrations
278    #[cfg(feature = "database")]
279    pub async fn run_migrations(&self) -> Result<(), Box<dyn std::error::Error>> {
280        println!("Migration system initialized for directory: {}", self.migrations_dir);
281
282        // In a production implementation, this would:
283        // 1. Create migrations table
284        // 2. Read migration files from directory
285        // 3. Execute pending migrations in order
286        // 4. Record completed migrations
287
288        // For now, we'll just log that migrations would run
289        println!("Migration system ready - would execute SQL files from {}", self.migrations_dir);
290        Ok(())
291    }
292
293    #[cfg(not(feature = "database"))]
294    pub async fn run_migrations(&self) -> Result<(), Box<dyn std::error::Error>> {
295        Err("Database feature not enabled".into())
296    }
297}
298
299/// Database health check
300pub async fn database_health_check(pool: &DatabasePool) -> Response {
301    #[cfg(feature = "database")]
302    {
303        match pool.query_json("SELECT 1 as health_check", &[]).await {
304            Ok(_) => Response::ok().json(&serde_json::json!({
305                "database": "healthy",
306                "timestamp": chrono::Utc::now().to_rfc3339()
307            })).unwrap_or_else(|_| Response::ok().body("healthy")),
308            Err(e) => Response::with_status(http::StatusCode::SERVICE_UNAVAILABLE)
309                .json(&serde_json::json!({
310                    "database": "unhealthy",
311                    "error": e.to_string(),
312                    "timestamp": chrono::Utc::now().to_rfc3339()
313                })).unwrap_or_else(|_| Response::with_status(http::StatusCode::SERVICE_UNAVAILABLE).body("unhealthy"))
314        }
315    }
316    
317    #[cfg(not(feature = "database"))]
318    {
319        let _ = pool; // Suppress unused variable warning
320        Response::ok().body("Database feature not enabled")
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every fragment in `fragments` occurs somewhere in `sql`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(*fragment));
        }
    }

    /// Build an owned `String -> String` map from borrowed pairs.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_query_builder_select() {
        let sql = QueryBuilder::new("users")
            .select(&["id", "name", "email"])
            .where_eq("active", "true")
            .order_by("created_at", "DESC")
            .limit(10)
            .build_select();

        assert_contains_all(
            &sql,
            &[
                "SELECT id, name, email FROM users",
                "WHERE active = 'true'",
                "ORDER BY created_at DESC",
                "LIMIT 10",
            ],
        );
    }

    #[test]
    fn test_query_builder_insert() {
        let data = string_map(&[("name", "John Doe"), ("email", "john@example.com")]);

        let sql = QueryBuilder::new("users").build_insert(&data);

        assert_contains_all(&sql, &["INSERT INTO users", "name", "email"]);
    }

    #[test]
    fn test_query_builder_update() {
        let data = string_map(&[("name", "Jane Doe")]);

        let sql = QueryBuilder::new("users")
            .where_eq("id", "1")
            .build_update(&data);

        assert_contains_all(
            &sql,
            &["UPDATE users SET", "name = 'Jane Doe'", "WHERE id = '1'"],
        );
    }

    #[test]
    fn test_query_builder_delete() {
        let sql = QueryBuilder::new("users")
            .where_eq("id", "1")
            .build_delete();

        assert_contains_all(&sql, &["DELETE FROM users", "WHERE id = '1'"]);
    }
}