1use std::collections::HashMap;
7use std::fmt;
8use thiserror::Error;
9
10use crate::dataset::Dataset;
11use crate::error::DataError;
12use torsh_core::TensorElement;
13use torsh_tensor::Tensor;
14
15#[derive(Error, Debug)]
16pub enum DatabaseError {
17 #[error("Connection error: {0}")]
18 ConnectionError(String),
19 #[error("Query error: {0}")]
20 QueryError(String),
21 #[error("Type conversion error: {0}")]
22 TypeConversionError(String),
23 #[error("Configuration error: {0}")]
24 ConfigError(String),
25 #[error("Column not found: {0}")]
26 ColumnNotFound(String),
27}
28
29impl From<DatabaseError> for DataError {
30 fn from(err: DatabaseError) -> Self {
31 DataError::Other(err.to_string())
32 }
33}
34
/// Database engines a configuration can target.
///
/// The enum is fieldless, so it is cheap to copy and can be used as a map or
/// set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DatabaseBackend {
    /// SQLite database.
    SQLite,
    /// PostgreSQL server.
    PostgreSQL,
    /// MySQL server.
    MySQL,
    /// In-memory backend.
    Memory,
}
43
44impl fmt::Display for DatabaseBackend {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 DatabaseBackend::SQLite => write!(f, "SQLite"),
48 DatabaseBackend::PostgreSQL => write!(f, "PostgreSQL"),
49 DatabaseBackend::MySQL => write!(f, "MySQL"),
50 DatabaseBackend::Memory => write!(f, "Memory"),
51 }
52 }
53}
54
/// A single cell value as read from a database.
///
/// `PartialEq` is derived so values can be compared directly (for example in
/// assertions). `Eq` is intentionally absent because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum DatabaseValue {
    /// 64-bit signed integer.
    Integer(i64),
    /// 64-bit floating point number.
    Float(f64),
    /// UTF-8 text.
    Text(String),
    /// Raw bytes.
    Blob(Vec<u8>),
    /// SQL `NULL`.
    Null,
}
64
65impl DatabaseValue {
66 pub fn to_tensor_element<T: TensorElement>(&self) -> std::result::Result<T, DatabaseError> {
68 match self {
69 DatabaseValue::Integer(val) => T::from_f64(*val as f64).ok_or_else(|| {
70 DatabaseError::TypeConversionError(format!(
71 "Cannot convert integer {val} to target type"
72 ))
73 }),
74 DatabaseValue::Float(val) => T::from_f64(*val).ok_or_else(|| {
75 DatabaseError::TypeConversionError(format!(
76 "Cannot convert float {val} to target type"
77 ))
78 }),
79 DatabaseValue::Text(val) => {
80 if let Ok(num) = val.parse::<f64>() {
82 T::from_f64(num).ok_or_else(|| {
83 DatabaseError::TypeConversionError(format!(
84 "Cannot convert parsed number {num} to target type"
85 ))
86 })
87 } else {
88 Err(DatabaseError::TypeConversionError(format!(
89 "Cannot convert text '{val}' to numeric type"
90 )))
91 }
92 }
93 DatabaseValue::Null => T::from_f64(0.0).ok_or_else(|| {
94 DatabaseError::TypeConversionError("Cannot convert NULL to target type".to_string())
95 }),
96 DatabaseValue::Blob(_) => Err(DatabaseError::TypeConversionError(
97 "Cannot convert BLOB to numeric type".to_string(),
98 )),
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct DatabaseRow {
106 columns: HashMap<String, DatabaseValue>,
107}
108
109impl DatabaseRow {
110 pub fn new() -> Self {
112 Self {
113 columns: HashMap::new(),
114 }
115 }
116
117 pub fn add_column(&mut self, name: String, value: DatabaseValue) {
119 self.columns.insert(name, value);
120 }
121
122 pub fn get_column(&self, name: &str) -> Option<&DatabaseValue> {
124 self.columns.get(name)
125 }
126
127 pub fn column_names(&self) -> Vec<&String> {
129 self.columns.keys().collect()
130 }
131
132 pub fn column_to_tensor_element<T: TensorElement>(
134 &self,
135 column_name: &str,
136 ) -> std::result::Result<T, DatabaseError> {
137 let value = self
138 .get_column(column_name)
139 .ok_or_else(|| DatabaseError::ColumnNotFound(column_name.to_string()))?;
140 value.to_tensor_element()
141 }
142
143 pub fn columns_to_tensor<T: TensorElement>(
145 &self,
146 column_names: &[&str],
147 ) -> std::result::Result<Tensor<T>, DatabaseError> {
148 let mut values = Vec::with_capacity(column_names.len());
149
150 for &column_name in column_names {
151 let tensor_value = self.column_to_tensor_element::<T>(column_name)?;
152 values.push(tensor_value);
153 }
154
155 let shape = vec![values.len()];
156 Tensor::from_vec(values, &shape)
157 .map_err(|e| DatabaseError::TypeConversionError(e.to_string()))
158 }
159}
160
161impl Default for DatabaseRow {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct DatabaseConfig {
170 pub backend: DatabaseBackend,
171 pub host: Option<String>,
172 pub port: Option<u16>,
173 pub database: String,
174 pub username: Option<String>,
175 pub password: Option<String>,
176 pub connection_string: Option<String>,
177}
178
179impl DatabaseConfig {
180 pub fn new(backend: DatabaseBackend, database: String) -> Self {
182 Self {
183 backend,
184 host: None,
185 port: None,
186 database,
187 username: None,
188 password: None,
189 connection_string: None,
190 }
191 }
192
193 pub fn with_host_port(mut self, host: String, port: u16) -> Self {
195 self.host = Some(host);
196 self.port = Some(port);
197 self
198 }
199
200 pub fn with_credentials(mut self, username: String, password: String) -> Self {
202 self.username = Some(username);
203 self.password = Some(password);
204 self
205 }
206
207 pub fn with_connection_string(mut self, connection_string: String) -> Self {
209 self.connection_string = Some(connection_string);
210 self
211 }
212
213 pub fn build_connection_string(&self) -> String {
215 if let Some(ref custom) = self.connection_string {
216 return custom.clone();
217 }
218
219 match self.backend {
220 DatabaseBackend::SQLite => {
221 format!("sqlite:{}", self.database)
222 }
223 DatabaseBackend::PostgreSQL => {
224 let host = self.host.as_deref().unwrap_or("localhost");
225 let port = self.port.unwrap_or(5432);
226 let username = self.username.as_deref().unwrap_or("postgres");
227 let password = self.password.as_deref().unwrap_or("");
228 format!(
229 "postgresql://{}:{}@{}:{}/{}",
230 username, password, host, port, self.database
231 )
232 }
233 DatabaseBackend::MySQL => {
234 let host = self.host.as_deref().unwrap_or("localhost");
235 let port = self.port.unwrap_or(3306);
236 let username = self.username.as_deref().unwrap_or("root");
237 let password = self.password.as_deref().unwrap_or("");
238 format!(
239 "mysql://{}:{}@{}:{}/{}",
240 username, password, host, port, self.database
241 )
242 }
243 DatabaseBackend::Memory => ":memory:".to_string(),
244 }
245 }
246}
247
248pub trait DatabaseConnection: Send + Sync {
250 fn execute_query(
252 &mut self,
253 query: &str,
254 ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError>;
255
256 fn get_table_names(&mut self) -> std::result::Result<Vec<String>, DatabaseError>;
258
259 fn get_column_names(
261 &mut self,
262 table_name: &str,
263 ) -> std::result::Result<Vec<String>, DatabaseError>;
264
265 fn count_rows(&mut self, table_name: &str) -> std::result::Result<usize, DatabaseError>;
267
268 fn close(&mut self) -> std::result::Result<(), DatabaseError>;
270}
271
272pub struct MockDatabaseConnection {
274 _backend: DatabaseBackend,
275 tables: HashMap<String, Vec<DatabaseRow>>,
276}
277
278impl MockDatabaseConnection {
279 pub fn new(backend: DatabaseBackend) -> Self {
281 let mut tables = HashMap::new();
282
283 let mut sample_rows = Vec::new();
285 for i in 0..100 {
286 let mut row = DatabaseRow::new();
287 row.add_column("id".to_string(), DatabaseValue::Integer(i));
288 row.add_column("value".to_string(), DatabaseValue::Float(i as f64 * 1.5));
289 row.add_column("name".to_string(), DatabaseValue::Text(format!("item_{i}")));
290 sample_rows.push(row);
291 }
292 tables.insert("sample_table".to_string(), sample_rows);
293
294 Self {
295 _backend: backend,
296 tables,
297 }
298 }
299}
300
301impl DatabaseConnection for MockDatabaseConnection {
302 fn execute_query(
303 &mut self,
304 query: &str,
305 ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError> {
306 let query_lower = query.to_lowercase();
308
309 if query_lower.contains("select") && query_lower.contains("from") {
310 if let Some(table_name) = query_lower.split("from").nth(1) {
312 let table_name = table_name.split_whitespace().next().unwrap_or("").trim();
313
314 if let Some(rows) = self.tables.get(table_name) {
315 if let Some(limit_part) = query_lower.split("limit").nth(1) {
317 if let Ok(limit) = limit_part.trim().parse::<usize>() {
318 return Ok(rows.iter().take(limit).cloned().collect());
319 }
320 }
321
322 return Ok(rows.clone());
323 }
324 }
325 }
326
327 Err(DatabaseError::QueryError(format!(
328 "Query not supported: {query}"
329 )))
330 }
331
332 fn get_table_names(&mut self) -> std::result::Result<Vec<String>, DatabaseError> {
333 Ok(self.tables.keys().cloned().collect())
334 }
335
336 fn get_column_names(
337 &mut self,
338 table_name: &str,
339 ) -> std::result::Result<Vec<String>, DatabaseError> {
340 if let Some(rows) = self.tables.get(table_name) {
341 if let Some(first_row) = rows.first() {
342 return Ok(first_row
343 .column_names()
344 .iter()
345 .map(|s| (*s).clone())
346 .collect());
347 }
348 }
349 Err(DatabaseError::QueryError(format!(
350 "Table not found: {table_name}"
351 )))
352 }
353
354 fn count_rows(&mut self, table_name: &str) -> std::result::Result<usize, DatabaseError> {
355 if let Some(rows) = self.tables.get(table_name) {
356 Ok(rows.len())
357 } else {
358 Err(DatabaseError::QueryError(format!(
359 "Table not found: {table_name}"
360 )))
361 }
362 }
363
364 fn close(&mut self) -> std::result::Result<(), DatabaseError> {
365 Ok(())
367 }
368}
369
370pub struct DatabaseDataset {
372 connection: Box<dyn DatabaseConnection>,
373 table_name: String,
374 columns: Vec<String>,
375 total_rows: usize,
376 _batch_size: usize,
377}
378
379impl DatabaseDataset {
380 pub fn new(
382 mut connection: Box<dyn DatabaseConnection>,
383 table_name: String,
384 columns: Option<Vec<String>>,
385 batch_size: Option<usize>,
386 ) -> std::result::Result<Self, DatabaseError> {
387 let columns = match columns {
389 Some(cols) => cols,
390 None => connection.get_column_names(&table_name)?,
391 };
392
393 let total_rows = connection.count_rows(&table_name)?;
394 let batch_size = batch_size.unwrap_or(1);
395
396 Ok(Self {
397 connection,
398 table_name,
399 columns,
400 total_rows,
401 _batch_size: batch_size,
402 })
403 }
404
405 pub fn columns(&self) -> &[String] {
407 &self.columns
408 }
409
410 pub fn table_name(&self) -> &str {
412 &self.table_name
413 }
414
415 pub fn read_batch(
417 &mut self,
418 start_idx: usize,
419 batch_size: usize,
420 ) -> std::result::Result<Vec<DatabaseRow>, DatabaseError> {
421 let query = format!(
422 "SELECT {} FROM {} LIMIT {} OFFSET {}",
423 self.columns.join(", "),
424 self.table_name,
425 batch_size,
426 start_idx
427 );
428
429 self.connection.execute_query(&query)
430 }
431
432 pub fn rows_to_tensors<T: TensorElement>(
434 &self,
435 rows: &[DatabaseRow],
436 ) -> std::result::Result<Vec<Tensor<T>>, DatabaseError> {
437 let mut column_tensors = Vec::new();
438
439 for column_name in &self.columns {
440 let mut column_values = Vec::with_capacity(rows.len());
441
442 for row in rows {
443 let value = row.column_to_tensor_element::<T>(column_name)?;
444 column_values.push(value);
445 }
446
447 let shape = vec![column_values.len()];
448 let tensor = Tensor::from_vec(column_values, &shape)
449 .map_err(|e| DatabaseError::TypeConversionError(e.to_string()))?;
450 column_tensors.push(tensor);
451 }
452
453 Ok(column_tensors)
454 }
455}
456
457impl Dataset for DatabaseDataset {
458 type Item = DatabaseRow;
459
460 fn len(&self) -> usize {
461 self.total_rows
462 }
463
464 fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
465 if index >= self.total_rows {
466 return Err(DataError::Other(format!(
467 "Index {} out of bounds for dataset of size {}",
468 index, self.total_rows
469 ))
470 .into());
471 }
472
473 let _query = format!(
475 "SELECT {} FROM {} LIMIT 1 OFFSET {}",
476 self.columns.join(", "),
477 self.table_name,
478 index
479 );
480
481 Err(DataError::Other(
484 "Individual row access not supported. Use batch operations instead.".to_string(),
485 )
486 .into())
487 }
488}
489
490pub struct DatabaseDatasetBuilder {
492 config: DatabaseConfig,
493 table_name: Option<String>,
494 columns: Option<Vec<String>>,
495 batch_size: Option<usize>,
496 query: Option<String>,
497}
498
499impl DatabaseDatasetBuilder {
500 pub fn new(config: DatabaseConfig) -> Self {
502 Self {
503 config,
504 table_name: None,
505 columns: None,
506 batch_size: None,
507 query: None,
508 }
509 }
510
511 pub fn table(mut self, table_name: String) -> Self {
513 self.table_name = Some(table_name);
514 self
515 }
516
517 pub fn columns(mut self, columns: Vec<String>) -> Self {
519 self.columns = Some(columns);
520 self
521 }
522
523 pub fn batch_size(mut self, batch_size: usize) -> Self {
525 self.batch_size = Some(batch_size);
526 self
527 }
528
529 pub fn query(mut self, query: String) -> Self {
531 self.query = Some(query);
532 self
533 }
534
535 pub fn build(self) -> std::result::Result<DatabaseDataset, DatabaseError> {
537 let connection: Box<dyn DatabaseConnection> = match self.config.backend {
538 DatabaseBackend::Memory => Box::new(MockDatabaseConnection::new(self.config.backend)),
539 _ => {
540 Box::new(MockDatabaseConnection::new(self.config.backend))
543 }
544 };
545
546 let table_name = self
547 .table_name
548 .ok_or_else(|| DatabaseError::ConfigError("Table name is required".to_string()))?;
549
550 DatabaseDataset::new(connection, table_name, self.columns, self.batch_size)
551 }
552}
553
554pub mod database_utils {
556 use super::*;
557
558 pub fn sqlite_config<P: AsRef<std::path::Path>>(database_path: P) -> DatabaseConfig {
560 DatabaseConfig::new(
561 DatabaseBackend::SQLite,
562 database_path.as_ref().to_string_lossy().to_string(),
563 )
564 }
565
566 pub fn postgresql_config(
568 host: &str,
569 port: u16,
570 database: &str,
571 username: &str,
572 password: &str,
573 ) -> DatabaseConfig {
574 DatabaseConfig::new(DatabaseBackend::PostgreSQL, database.to_string())
575 .with_host_port(host.to_string(), port)
576 .with_credentials(username.to_string(), password.to_string())
577 }
578
579 pub fn mysql_config(
581 host: &str,
582 port: u16,
583 database: &str,
584 username: &str,
585 password: &str,
586 ) -> DatabaseConfig {
587 DatabaseConfig::new(DatabaseBackend::MySQL, database.to_string())
588 .with_host_port(host.to_string(), port)
589 .with_credentials(username.to_string(), password.to_string())
590 }
591
592 pub fn memory_config() -> DatabaseConfig {
594 DatabaseConfig::new(DatabaseBackend::Memory, ":memory:".to_string())
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Numeric and numeric-text values convert; non-numeric text and blobs
    /// are rejected.
    #[test]
    fn test_database_value_conversion() {
        let int_val = DatabaseValue::Integer(42);
        // 2.5 rather than 3.14: clippy's `approx_constant` lint rejects
        // literals that approximate PI.
        let float_val = DatabaseValue::Float(2.5);
        let text_val = DatabaseValue::Text("123.45".to_string());

        assert!(int_val.to_tensor_element::<f32>().is_ok());
        assert!(float_val.to_tensor_element::<f64>().is_ok());
        assert!(text_val.to_tensor_element::<f32>().is_ok());

        let bad_text = DatabaseValue::Text("not a number".to_string());
        assert!(matches!(
            bad_text.to_tensor_element::<f32>(),
            Err(DatabaseError::TypeConversionError(_))
        ));
        let blob = DatabaseValue::Blob(vec![1, 2, 3]);
        assert!(matches!(
            blob.to_tensor_element::<f32>(),
            Err(DatabaseError::TypeConversionError(_))
        ));
    }

    /// Columns can be stored, looked up and listed; a missing column is
    /// reported as `ColumnNotFound`.
    #[test]
    fn test_database_row() {
        let mut row = DatabaseRow::new();
        row.add_column("id".to_string(), DatabaseValue::Integer(1));
        row.add_column("value".to_string(), DatabaseValue::Float(2.5));

        assert!(matches!(
            row.get_column("id"),
            Some(DatabaseValue::Integer(1))
        ));
        assert!(row.get_column("nonexistent").is_none());
        assert_eq!(row.column_names().len(), 2);
        assert!(matches!(
            row.column_to_tensor_element::<f32>("nonexistent"),
            Err(DatabaseError::ColumnNotFound(_))
        ));
    }

    /// Connection strings follow the documented per-backend formats and a
    /// custom string overrides them.
    #[test]
    fn test_database_config() {
        let config = DatabaseConfig::new(DatabaseBackend::SQLite, "test.db".to_string());
        assert_eq!(config.build_connection_string(), "sqlite:test.db");

        let pg_config =
            database_utils::postgresql_config("localhost", 5432, "testdb", "user", "pass");
        assert_eq!(
            pg_config.build_connection_string(),
            "postgresql://user:pass@localhost:5432/testdb"
        );

        assert_eq!(
            database_utils::memory_config().build_connection_string(),
            ":memory:"
        );

        let custom = DatabaseConfig::new(DatabaseBackend::MySQL, "db".to_string())
            .with_connection_string("mysql://custom".to_string());
        assert_eq!(custom.build_connection_string(), "mysql://custom");
    }

    /// The mock exposes its sample table and answers simple SELECT queries.
    #[test]
    fn test_mock_connection() {
        let mut conn = MockDatabaseConnection::new(DatabaseBackend::Memory);

        let tables = conn.get_table_names().unwrap();
        assert_eq!(tables, vec!["sample_table".to_string()]);

        let columns = conn.get_column_names("sample_table").unwrap();
        assert_eq!(columns.len(), 3);

        let count = conn.count_rows("sample_table").unwrap();
        assert_eq!(count, 100);

        let limited = conn
            .execute_query("SELECT * FROM sample_table LIMIT 5")
            .unwrap();
        assert_eq!(limited.len(), 5);

        assert!(conn.execute_query("SELECT * FROM missing_table").is_err());
        assert!(conn.count_rows("missing_table").is_err());
    }

    /// A fully specified builder succeeds; a builder without a table name
    /// fails with a configuration error.
    #[test]
    fn test_database_dataset_builder() {
        let config = database_utils::memory_config();
        let builder = DatabaseDatasetBuilder::new(config)
            .table("sample_table".to_string())
            .columns(vec!["id".to_string(), "value".to_string()])
            .batch_size(10);

        let dataset = builder.build().unwrap();
        assert_eq!(dataset.table_name(), "sample_table");
        assert_eq!(
            dataset.columns(),
            ["id".to_string(), "value".to_string()]
        );

        let missing_table = DatabaseDatasetBuilder::new(database_utils::memory_config()).build();
        assert!(matches!(
            missing_table,
            Err(DatabaseError::ConfigError(_))
        ));
    }
}