torsh_data/
database_integration.rs

1//! Database integration for loading data from various database backends
2//!
3//! This module provides a unified interface for loading data from different
4//! database systems such as SQLite, PostgreSQL, MySQL, etc.
5
6use std::collections::HashMap;
7use std::fmt;
8use thiserror::Error;
9
10use crate::dataset::Dataset;
11use crate::error::DataError;
12use torsh_core::TensorElement;
13use torsh_tensor::Tensor;
14
/// Errors raised by the database integration layer.
///
/// Every variant carries a detail string that is interpolated into the
/// `Display` message generated by `thiserror`.
#[derive(Error, Debug)]
pub enum DatabaseError {
    /// A connection-level failure. Not constructed anywhere in the visible code.
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// A query could not be served; in this file that means unsupported query
    /// text or an unknown table.
    #[error("Query error: {0}")]
    QueryError(String),
    /// A database value could not be turned into the requested tensor element
    /// (or the resulting tensor could not be built).
    #[error("Type conversion error: {0}")]
    TypeConversionError(String),
    /// The supplied configuration is incomplete, e.g. a missing table name.
    #[error("Configuration error: {0}")]
    ConfigError(String),
    /// A row does not contain the requested column; the payload is the column name.
    #[error("Column not found: {0}")]
    ColumnNotFound(String),
}
28
29impl From<DatabaseError> for DataError {
30    fn from(err: DatabaseError) -> Self {
31        DataError::Other(err.to_string())
32    }
33}
34
/// Supported database backends.
///
/// The variant decides how `DatabaseConfig::build_connection_string` formats
/// the connection string when no custom string is configured.
#[derive(Debug, Clone, PartialEq)]
pub enum DatabaseBackend {
    /// SQLite; connection strings take the form `sqlite:<database>`.
    SQLite,
    /// PostgreSQL; the connection string defaults to port 5432 and user `postgres`.
    PostgreSQL,
    /// MySQL; the connection string defaults to port 3306 and user `root`.
    MySQL,
    /// In-memory database for testing; the connection string is `:memory:`.
    Memory,
}
43
44impl fmt::Display for DatabaseBackend {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        match self {
47            DatabaseBackend::SQLite => write!(f, "SQLite"),
48            DatabaseBackend::PostgreSQL => write!(f, "PostgreSQL"),
49            DatabaseBackend::MySQL => write!(f, "MySQL"),
50            DatabaseBackend::Memory => write!(f, "Memory"),
51        }
52    }
53}
54
/// Database value types.
///
/// A dynamically typed cell of a query result.
#[derive(Debug, Clone)]
pub enum DatabaseValue {
    /// 64-bit signed integer.
    Integer(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Text; `to_tensor_element` accepts it only when it parses as `f64`.
    Text(String),
    /// Raw bytes; `to_tensor_element` always rejects this variant.
    Blob(Vec<u8>),
    /// SQL NULL; `to_tensor_element` converts it as `0.0`.
    Null,
}
64
65impl DatabaseValue {
66    /// Convert to a tensor element if possible
67    pub fn to_tensor_element<T: TensorElement>(&self) -> std::result::Result<T, DatabaseError> {
68        match self {
69            DatabaseValue::Integer(val) => T::from_f64(*val as f64).ok_or_else(|| {
70                DatabaseError::TypeConversionError(format!(
71                    "Cannot convert integer {val} to target type"
72                ))
73            }),
74            DatabaseValue::Float(val) => T::from_f64(*val).ok_or_else(|| {
75                DatabaseError::TypeConversionError(format!(
76                    "Cannot convert float {val} to target type"
77                ))
78            }),
79            DatabaseValue::Text(val) => {
80                // Try to parse as number
81                if let Ok(num) = val.parse::<f64>() {
82                    T::from_f64(num).ok_or_else(|| {
83                        DatabaseError::TypeConversionError(format!(
84                            "Cannot convert parsed number {num} to target type"
85                        ))
86                    })
87                } else {
88                    Err(DatabaseError::TypeConversionError(format!(
89                        "Cannot convert text '{val}' to numeric type"
90                    )))
91                }
92            }
93            DatabaseValue::Null => T::from_f64(0.0).ok_or_else(|| {
94                DatabaseError::TypeConversionError("Cannot convert NULL to target type".to_string())
95            }),
96            DatabaseValue::Blob(_) => Err(DatabaseError::TypeConversionError(
97                "Cannot convert BLOB to numeric type".to_string(),
98            )),
99        }
100    }
101}
102
/// A row of data from a database query result.
///
/// Values are stored in a `HashMap` keyed by column name, so the row keeps no
/// column order and holds at most one value per name.
#[derive(Debug, Clone)]
pub struct DatabaseRow {
    // Column name -> value for this row.
    columns: HashMap<String, DatabaseValue>,
}
108
109impl DatabaseRow {
110    /// Create a new database row
111    pub fn new() -> Self {
112        Self {
113            columns: HashMap::new(),
114        }
115    }
116
117    /// Add a column value
118    pub fn add_column(&mut self, name: String, value: DatabaseValue) {
119        self.columns.insert(name, value);
120    }
121
122    /// Get a column value by name
123    pub fn get_column(&self, name: &str) -> Option<&DatabaseValue> {
124        self.columns.get(name)
125    }
126
127    /// Get all column names
128    pub fn column_names(&self) -> Vec<&String> {
129        self.columns.keys().collect()
130    }
131
132    /// Convert a column to a tensor element
133    pub fn column_to_tensor_element<T: TensorElement>(
134        &self,
135        column_name: &str,
136    ) -> std::result::Result<T, DatabaseError> {
137        let value = self
138            .get_column(column_name)
139            .ok_or_else(|| DatabaseError::ColumnNotFound(column_name.to_string()))?;
140        value.to_tensor_element()
141    }
142
143    /// Convert multiple columns to a tensor
144    pub fn columns_to_tensor<T: TensorElement>(
145        &self,
146        column_names: &[&str],
147    ) -> std::result::Result<Tensor<T>, DatabaseError> {
148        let mut values = Vec::with_capacity(column_names.len());
149
150        for &column_name in column_names {
151            let tensor_value = self.column_to_tensor_element::<T>(column_name)?;
152            values.push(tensor_value);
153        }
154
155        let shape = vec![values.len()];
156        Tensor::from_vec(values, &shape)
157            .map_err(|e| DatabaseError::TypeConversionError(e.to_string()))
158    }
159}
160
161impl Default for DatabaseRow {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
/// Configuration for database connections.
///
/// Optional fields left as `None` are filled with per-backend defaults by
/// `build_connection_string`.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
    /// Which backend the connection string is built for.
    pub backend: DatabaseBackend,
    /// Server host; `build_connection_string` falls back to `localhost`.
    pub host: Option<String>,
    /// Server port; falls back to 5432 for PostgreSQL and 3306 for MySQL.
    pub port: Option<u16>,
    /// Database name; for SQLite it is appended after the `sqlite:` prefix.
    pub database: String,
    /// User name; falls back to `postgres` for PostgreSQL and `root` for MySQL.
    pub username: Option<String>,
    /// Password; falls back to the empty string.
    pub password: Option<String>,
    /// Custom connection string; when set it is returned verbatim and all
    /// other fields are ignored by `build_connection_string`.
    pub connection_string: Option<String>,
}
178
179impl DatabaseConfig {
180    /// Create a new database config
181    pub fn new(backend: DatabaseBackend, database: String) -> Self {
182        Self {
183            backend,
184            host: None,
185            port: None,
186            database,
187            username: None,
188            password: None,
189            connection_string: None,
190        }
191    }
192
193    /// Set host and port
194    pub fn with_host_port(mut self, host: String, port: u16) -> Self {
195        self.host = Some(host);
196        self.port = Some(port);
197        self
198    }
199
200    /// Set credentials
201    pub fn with_credentials(mut self, username: String, password: String) -> Self {
202        self.username = Some(username);
203        self.password = Some(password);
204        self
205    }
206
207    /// Set custom connection string
208    pub fn with_connection_string(mut self, connection_string: String) -> Self {
209        self.connection_string = Some(connection_string);
210        self
211    }
212
213    /// Build connection string based on backend
214    pub fn build_connection_string(&self) -> String {
215        if let Some(ref custom) = self.connection_string {
216            return custom.clone();
217        }
218
219        match self.backend {
220            DatabaseBackend::SQLite => {
221                format!("sqlite:{}", self.database)
222            }
223            DatabaseBackend::PostgreSQL => {
224                let host = self.host.as_deref().unwrap_or("localhost");
225                let port = self.port.unwrap_or(5432);
226                let username = self.username.as_deref().unwrap_or("postgres");
227                let password = self.password.as_deref().unwrap_or("");
228                format!(
229                    "postgresql://{}:{}@{}:{}/{}",
230                    username, password, host, port, self.database
231                )
232            }
233            DatabaseBackend::MySQL => {
234                let host = self.host.as_deref().unwrap_or("localhost");
235                let port = self.port.unwrap_or(3306);
236                let username = self.username.as_deref().unwrap_or("root");
237                let password = self.password.as_deref().unwrap_or("");
238                format!(
239                    "mysql://{}:{}@{}:{}/{}",
240                    username, password, host, port, self.database
241                )
242            }
243            DatabaseBackend::Memory => ":memory:".to_string(),
244        }
245    }
246}
247
/// Trait for database connections.
///
/// Implementors must be `Send + Sync`. Every method takes `&mut self`, so a
/// connection cannot be driven through a shared reference.
pub trait DatabaseConnection: Send + Sync {
    /// Execute a query and return the resulting rows.
    fn execute_query(
        &mut self,
        query: &str,
    ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError>;

    /// Get the names of the tables available through this connection.
    fn get_table_names(&mut self) -> std::result::Result<Vec<String>, DatabaseError>;

    /// Get the column names of `table_name`.
    fn get_column_names(
        &mut self,
        table_name: &str,
    ) -> std::result::Result<Vec<String>, DatabaseError>;

    /// Count the rows in `table_name`.
    fn count_rows(&mut self, table_name: &str) -> std::result::Result<usize, DatabaseError>;

    /// Close the connection.
    fn close(&mut self) -> std::result::Result<(), DatabaseError>;
}
271
/// Mock database connection for testing and demonstration.
///
/// All data lives in memory; no real database is contacted.
pub struct MockDatabaseConnection {
    // Backend this mock stands in for; stored but not read in the visible code.
    _backend: DatabaseBackend,
    // Table name -> rows of that table.
    tables: HashMap<String, Vec<DatabaseRow>>,
}
277
278impl MockDatabaseConnection {
279    /// Create a new mock connection
280    pub fn new(backend: DatabaseBackend) -> Self {
281        let mut tables = HashMap::new();
282
283        // Create some sample data
284        let mut sample_rows = Vec::new();
285        for i in 0..100 {
286            let mut row = DatabaseRow::new();
287            row.add_column("id".to_string(), DatabaseValue::Integer(i));
288            row.add_column("value".to_string(), DatabaseValue::Float(i as f64 * 1.5));
289            row.add_column("name".to_string(), DatabaseValue::Text(format!("item_{i}")));
290            sample_rows.push(row);
291        }
292        tables.insert("sample_table".to_string(), sample_rows);
293
294        Self {
295            _backend: backend,
296            tables,
297        }
298    }
299}
300
301impl DatabaseConnection for MockDatabaseConnection {
302    fn execute_query(
303        &mut self,
304        query: &str,
305    ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError> {
306        // Very simple query parsing for demo purposes
307        let query_lower = query.to_lowercase();
308
309        if query_lower.contains("select") && query_lower.contains("from") {
310            // Extract table name (very simplified)
311            if let Some(table_name) = query_lower.split("from").nth(1) {
312                let table_name = table_name.split_whitespace().next().unwrap_or("").trim();
313
314                if let Some(rows) = self.tables.get(table_name) {
315                    // Apply LIMIT if present
316                    if let Some(limit_part) = query_lower.split("limit").nth(1) {
317                        if let Ok(limit) = limit_part.trim().parse::<usize>() {
318                            return Ok(rows.iter().take(limit).cloned().collect());
319                        }
320                    }
321
322                    return Ok(rows.clone());
323                }
324            }
325        }
326
327        Err(DatabaseError::QueryError(format!(
328            "Query not supported: {query}"
329        )))
330    }
331
332    fn get_table_names(&mut self) -> std::result::Result<Vec<String>, DatabaseError> {
333        Ok(self.tables.keys().cloned().collect())
334    }
335
336    fn get_column_names(
337        &mut self,
338        table_name: &str,
339    ) -> std::result::Result<Vec<String>, DatabaseError> {
340        if let Some(rows) = self.tables.get(table_name) {
341            if let Some(first_row) = rows.first() {
342                return Ok(first_row
343                    .column_names()
344                    .iter()
345                    .map(|s| (*s).clone())
346                    .collect());
347            }
348        }
349        Err(DatabaseError::QueryError(format!(
350            "Table not found: {table_name}"
351        )))
352    }
353
354    fn count_rows(&mut self, table_name: &str) -> std::result::Result<usize, DatabaseError> {
355        if let Some(rows) = self.tables.get(table_name) {
356            Ok(rows.len())
357        } else {
358            Err(DatabaseError::QueryError(format!(
359                "Table not found: {table_name}"
360            )))
361        }
362    }
363
364    fn close(&mut self) -> std::result::Result<(), DatabaseError> {
365        // Nothing to do for mock connection
366        Ok(())
367    }
368}
369
/// Dataset that loads data from a database table.
pub struct DatabaseDataset {
    // Connection used to run the batch queries.
    connection: Box<dyn DatabaseConnection>,
    // Table the dataset reads from.
    table_name: String,
    // Columns selected by `read_batch` and converted by `rows_to_tensors`.
    columns: Vec<String>,
    // Row count captured once when the dataset is constructed.
    total_rows: usize,
    // Requested batch size (defaults to 1); stored but not read in the visible code.
    _batch_size: usize,
}
378
379impl DatabaseDataset {
380    /// Create a new database dataset
381    pub fn new(
382        mut connection: Box<dyn DatabaseConnection>,
383        table_name: String,
384        columns: Option<Vec<String>>,
385        batch_size: Option<usize>,
386    ) -> std::result::Result<Self, DatabaseError> {
387        // Get column names if not specified
388        let columns = match columns {
389            Some(cols) => cols,
390            None => connection.get_column_names(&table_name)?,
391        };
392
393        let total_rows = connection.count_rows(&table_name)?;
394        let batch_size = batch_size.unwrap_or(1);
395
396        Ok(Self {
397            connection,
398            table_name,
399            columns,
400            total_rows,
401            _batch_size: batch_size,
402        })
403    }
404
405    /// Get column names
406    pub fn columns(&self) -> &[String] {
407        &self.columns
408    }
409
410    /// Get table name
411    pub fn table_name(&self) -> &str {
412        &self.table_name
413    }
414
415    /// Read a batch of rows
416    pub fn read_batch(
417        &mut self,
418        start_idx: usize,
419        batch_size: usize,
420    ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError> {
421        let query = format!(
422            "SELECT {} FROM {} LIMIT {} OFFSET {}",
423            self.columns.join(", "),
424            self.table_name,
425            batch_size,
426            start_idx
427        );
428
429        self.connection.execute_query(&query)
430    }
431
432    /// Convert rows to tensors
433    pub fn rows_to_tensors<T: TensorElement>(
434        &self,
435        rows: &[DatabaseRow],
436    ) -> std::result::Result<Vec<Tensor<T>>, DatabaseError> {
437        let mut column_tensors = Vec::new();
438
439        for column_name in &self.columns {
440            let mut column_values = Vec::with_capacity(rows.len());
441
442            for row in rows {
443                let value = row.column_to_tensor_element::<T>(column_name)?;
444                column_values.push(value);
445            }
446
447            let shape = vec![column_values.len()];
448            let tensor = Tensor::from_vec(column_values, &shape)
449                .map_err(|e| DatabaseError::TypeConversionError(e.to_string()))?;
450            column_tensors.push(tensor);
451        }
452
453        Ok(column_tensors)
454    }
455}
456
457impl Dataset for DatabaseDataset {
458    type Item = DatabaseRow;
459
460    fn len(&self) -> usize {
461        self.total_rows
462    }
463
464    fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
465        if index >= self.total_rows {
466            return Err(DataError::Other(format!(
467                "Index {} out of bounds for dataset of size {}",
468                index, self.total_rows
469            ))
470            .into());
471        }
472
473        // This is inefficient for individual row access but works for demonstration
474        let _query = format!(
475            "SELECT {} FROM {} LIMIT 1 OFFSET {}",
476            self.columns.join(", "),
477            self.table_name,
478            index
479        );
480
481        // Since we need &mut self but trait requires &self, we'll create a simple workaround
482        // In practice, you'd design this differently or use interior mutability
483        Err(DataError::Other(
484            "Individual row access not supported. Use batch operations instead.".to_string(),
485        )
486        .into())
487    }
488}
489
/// Builder for creating database datasets.
pub struct DatabaseDatasetBuilder {
    // Connection settings; `build` reads the backend from here.
    config: DatabaseConfig,
    // Table to read from; `build` fails if this is still `None`.
    table_name: Option<String>,
    // Columns to select; `None` lets the dataset look them up.
    columns: Option<Vec<String>>,
    // Optional batch size forwarded to the dataset.
    batch_size: Option<usize>,
    // Custom query; stored by `query()` but not consumed by `build` in the visible code.
    query: Option<String>,
}
498
499impl DatabaseDatasetBuilder {
500    /// Create a new builder
501    pub fn new(config: DatabaseConfig) -> Self {
502        Self {
503            config,
504            table_name: None,
505            columns: None,
506            batch_size: None,
507            query: None,
508        }
509    }
510
511    /// Set the table name
512    pub fn table(mut self, table_name: String) -> Self {
513        self.table_name = Some(table_name);
514        self
515    }
516
517    /// Set the columns to select
518    pub fn columns(mut self, columns: Vec<String>) -> Self {
519        self.columns = Some(columns);
520        self
521    }
522
523    /// Set the batch size
524    pub fn batch_size(mut self, batch_size: usize) -> Self {
525        self.batch_size = Some(batch_size);
526        self
527    }
528
529    /// Set a custom query
530    pub fn query(mut self, query: String) -> Self {
531        self.query = Some(query);
532        self
533    }
534
535    /// Build the database dataset
536    pub fn build(self) -> std::result::Result<DatabaseDataset, DatabaseError> {
537        let connection: Box<dyn DatabaseConnection> = match self.config.backend {
538            DatabaseBackend::Memory => Box::new(MockDatabaseConnection::new(self.config.backend)),
539            _ => {
540                // For now, use mock connection for all backends
541                // In a real implementation, you'd create actual database connections
542                Box::new(MockDatabaseConnection::new(self.config.backend))
543            }
544        };
545
546        let table_name = self
547            .table_name
548            .ok_or_else(|| DatabaseError::ConfigError("Table name is required".to_string()))?;
549
550        DatabaseDataset::new(connection, table_name, self.columns, self.batch_size)
551    }
552}
553
554/// Utility functions for database operations
555pub mod database_utils {
556    use super::*;
557
558    /// Create a SQLite database configuration
559    pub fn sqlite_config<P: AsRef<std::path::Path>>(database_path: P) -> DatabaseConfig {
560        DatabaseConfig::new(
561            DatabaseBackend::SQLite,
562            database_path.as_ref().to_string_lossy().to_string(),
563        )
564    }
565
566    /// Create a PostgreSQL database configuration
567    pub fn postgresql_config(
568        host: &str,
569        port: u16,
570        database: &str,
571        username: &str,
572        password: &str,
573    ) -> DatabaseConfig {
574        DatabaseConfig::new(DatabaseBackend::PostgreSQL, database.to_string())
575            .with_host_port(host.to_string(), port)
576            .with_credentials(username.to_string(), password.to_string())
577    }
578
579    /// Create a MySQL database configuration
580    pub fn mysql_config(
581        host: &str,
582        port: u16,
583        database: &str,
584        username: &str,
585        password: &str,
586    ) -> DatabaseConfig {
587        DatabaseConfig::new(DatabaseBackend::MySQL, database.to_string())
588            .with_host_port(host.to_string(), port)
589            .with_credentials(username.to_string(), password.to_string())
590    }
591
592    /// Create an in-memory database configuration for testing
593    pub fn memory_config() -> DatabaseConfig {
594        DatabaseConfig::new(DatabaseBackend::Memory, ":memory:".to_string())
595    }
596}
597
/// Unit tests for value conversion, rows, configuration, the mock connection
/// and the dataset builder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Numeric values and numeric text convert; non-numeric text and blobs do not.
    #[test]
    fn test_database_value_conversion() {
        let int_val = DatabaseValue::Integer(42);
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate PI.
        let float_val = DatabaseValue::Float(2.5);
        let text_val = DatabaseValue::Text("123.45".to_string());

        assert!(int_val.to_tensor_element::<f32>().is_ok());
        assert!(float_val.to_tensor_element::<f64>().is_ok());
        assert!(text_val.to_tensor_element::<f32>().is_ok());

        let bad_text = DatabaseValue::Text("not a number".to_string());
        assert!(bad_text.to_tensor_element::<f32>().is_err());
        let blob = DatabaseValue::Blob(vec![1, 2, 3]);
        assert!(blob.to_tensor_element::<f32>().is_err());
    }

    /// Columns can be added and looked up; unknown columns are reported as such.
    #[test]
    fn test_database_row() {
        let mut row = DatabaseRow::new();
        row.add_column("id".to_string(), DatabaseValue::Integer(1));
        row.add_column("value".to_string(), DatabaseValue::Float(2.5));

        assert!(row.get_column("id").is_some());
        assert!(row.get_column("nonexistent").is_none());
        assert_eq!(row.column_names().len(), 2);

        assert!(row.column_to_tensor_element::<f32>("id").is_ok());
        assert!(matches!(
            row.column_to_tensor_element::<f32>("nonexistent"),
            Err(DatabaseError::ColumnNotFound(_))
        ));
    }

    /// Connection strings follow the per-backend format; a custom string wins.
    #[test]
    fn test_database_config() {
        let config = DatabaseConfig::new(DatabaseBackend::SQLite, "test.db".to_string());
        assert_eq!(config.build_connection_string(), "sqlite:test.db");

        let pg_config =
            database_utils::postgresql_config("localhost", 5432, "testdb", "user", "pass");
        assert_eq!(
            pg_config.build_connection_string(),
            "postgresql://user:pass@localhost:5432/testdb"
        );

        let memory = database_utils::memory_config();
        assert_eq!(memory.build_connection_string(), ":memory:");

        let custom = DatabaseConfig::new(DatabaseBackend::MySQL, "shop".to_string())
            .with_connection_string("mysql://custom".to_string());
        assert_eq!(custom.build_connection_string(), "mysql://custom");
    }

    /// The mock exposes its sample table through every trait method.
    #[test]
    fn test_mock_connection() {
        let mut conn = MockDatabaseConnection::new(DatabaseBackend::Memory);

        let tables = conn.get_table_names().unwrap();
        assert_eq!(tables, vec!["sample_table".to_string()]);

        let columns = conn.get_column_names("sample_table").unwrap();
        assert_eq!(columns.len(), 3);

        let count = conn.count_rows("sample_table").unwrap();
        assert_eq!(count, 100);

        let limited = conn
            .execute_query("SELECT * FROM sample_table LIMIT 5")
            .unwrap();
        assert_eq!(limited.len(), 5);

        assert!(conn.count_rows("missing_table").is_err());
        assert!(conn.execute_query("SELECT * FROM missing_table").is_err());
        assert!(conn.close().is_ok());
    }

    /// The builder produces a dataset when a table is set and fails without one.
    #[test]
    fn test_database_dataset_builder() {
        let config = database_utils::memory_config();
        let builder = DatabaseDatasetBuilder::new(config)
            .table("sample_table".to_string())
            .columns(vec!["id".to_string(), "value".to_string()])
            .batch_size(10);

        let dataset = builder.build();
        assert!(dataset.is_ok());
        if let Ok(dataset) = dataset {
            assert_eq!(dataset.table_name(), "sample_table");
            assert_eq!(dataset.columns(), ["id".to_string(), "value".to_string()]);
        }

        let missing_table = DatabaseDatasetBuilder::new(database_utils::memory_config()).build();
        assert!(matches!(
            missing_table,
            Err(DatabaseError::ConfigError(_))
        ));
    }
}