Skip to main content

sig_runtime/
connectors.rs

//! Data connectors for SQL databases and cloud storage.
//!
//! Provides a unified interface for loading data from various sources into polars DataFrames.
4
5use polars::prelude::*;
6use sig_types::{Result, SigcError};
7use std::collections::HashMap;
8
/// Connector configuration
///
/// Credentials (passwords, access keys, secret keys) are stored as plain
/// `String`s. `Debug` is implemented by hand so that formatting a config
/// with `{:?}` (e.g. in logs) prints `<redacted>` instead of the secret.
#[derive(Clone)]
pub enum ConnectorConfig {
    /// PostgreSQL connection
    Postgres {
        host: String,
        port: u16,
        database: String,
        user: String,
        password: String,
    },
    /// Snowflake connection
    Snowflake {
        account: String,
        warehouse: String,
        database: String,
        schema: String,
        user: String,
        password: String,
    },
    /// AWS S3
    S3 {
        bucket: String,
        region: String,
        access_key: Option<String>,
        secret_key: Option<String>,
    },
    /// Google Cloud Storage
    Gcs {
        bucket: String,
        project: String,
        credentials_path: Option<String>,
    },
    /// Azure Blob Storage
    Azure {
        container: String,
        account: String,
        access_key: Option<String>,
    },
}

impl std::fmt::Debug for ConnectorConfig {
    /// Formats the same shape `#[derive(Debug)]` would, but replaces every
    /// password and access/secret key with `<redacted>`. An unset optional
    /// key still prints as `None`, so "no credentials" stays visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let redact = |secret: &Option<String>| secret.as_ref().map(|_| REDACTED);

        match self {
            ConnectorConfig::Postgres { host, port, database, user, password: _ } => f
                .debug_struct("Postgres")
                .field("host", host)
                .field("port", port)
                .field("database", database)
                .field("user", user)
                .field("password", &REDACTED)
                .finish(),
            ConnectorConfig::Snowflake {
                account,
                warehouse,
                database,
                schema,
                user,
                password: _,
            } => f
                .debug_struct("Snowflake")
                .field("account", account)
                .field("warehouse", warehouse)
                .field("database", database)
                .field("schema", schema)
                .field("user", user)
                .field("password", &REDACTED)
                .finish(),
            ConnectorConfig::S3 { bucket, region, access_key, secret_key } => f
                .debug_struct("S3")
                .field("bucket", bucket)
                .field("region", region)
                .field("access_key", &redact(access_key))
                .field("secret_key", &redact(secret_key))
                .finish(),
            ConnectorConfig::Gcs { bucket, project, credentials_path } => f
                .debug_struct("Gcs")
                .field("bucket", bucket)
                .field("project", project)
                .field("credentials_path", credentials_path)
                .finish(),
            ConnectorConfig::Azure { container, account, access_key } => f
                .debug_struct("Azure")
                .field("container", container)
                .field("account", account)
                .field("access_key", &redact(access_key))
                .finish(),
        }
    }
}
49
50/// Data connector trait
51pub trait Connector: Send + Sync {
52    /// Load data from a path/query
53    fn load(&self, path: &str) -> Result<DataFrame>;
54
55    /// Check if connector is available
56    fn is_available(&self) -> bool;
57
58    /// Get connector name
59    fn name(&self) -> &str;
60}
61
62/// SQL connector for databases
63pub struct SqlConnector {
64    config: ConnectorConfig,
65    name: String,
66}
67
68impl SqlConnector {
69    /// Create a PostgreSQL connector
70    pub fn postgres(host: &str, port: u16, database: &str, user: &str, password: &str) -> Self {
71        SqlConnector {
72            config: ConnectorConfig::Postgres {
73                host: host.to_string(),
74                port,
75                database: database.to_string(),
76                user: user.to_string(),
77                password: password.to_string(),
78            },
79            name: "postgres".to_string(),
80        }
81    }
82
83    /// Create a Snowflake connector
84    pub fn snowflake(
85        account: &str,
86        warehouse: &str,
87        database: &str,
88        schema: &str,
89        user: &str,
90        password: &str,
91    ) -> Self {
92        SqlConnector {
93            config: ConnectorConfig::Snowflake {
94                account: account.to_string(),
95                warehouse: warehouse.to_string(),
96                database: database.to_string(),
97                schema: schema.to_string(),
98                user: user.to_string(),
99                password: password.to_string(),
100            },
101            name: "snowflake".to_string(),
102        }
103    }
104
105    /// Build connection string
106    #[allow(dead_code)]
107    fn connection_string(&self) -> String {
108        match &self.config {
109            ConnectorConfig::Postgres { host, port, database, user, password } => {
110                format!("postgresql://{}:{}@{}:{}/{}", user, password, host, port, database)
111            }
112            ConnectorConfig::Snowflake { account, warehouse, database, schema, user, password } => {
113                format!(
114                    "snowflake://{}:{}@{}/{}/{}?warehouse={}",
115                    user, password, account, database, schema, warehouse
116                )
117            }
118            _ => String::new(),
119        }
120    }
121}
122
123impl Connector for SqlConnector {
124    fn load(&self, query: &str) -> Result<DataFrame> {
125        match &self.config {
126            ConnectorConfig::Postgres { host, port, database, user, password } => {
127                self.load_postgres(host, *port, database, user, password, query)
128            }
129            ConnectorConfig::Snowflake { .. } => {
130                // Snowflake requires their specific driver
131                Err(SigcError::Runtime(
132                    "Snowflake connector requires snowflake-connector. Use ODBC or REST API.".into()
133                ))
134            }
135            _ => Err(SigcError::Runtime("Invalid config for SQL connector".into())),
136        }
137    }
138
139    fn is_available(&self) -> bool {
140        match &self.config {
141            ConnectorConfig::Postgres { host, port, database, user, password } => {
142                let conn_str = format!(
143                    "host={} port={} dbname={} user={} password={}",
144                    host, port, database, user, password
145                );
146                postgres::Client::connect(&conn_str, postgres::NoTls).is_ok()
147            }
148            _ => false,
149        }
150    }
151
152    fn name(&self) -> &str {
153        &self.name
154    }
155}
156
157impl SqlConnector {
158    /// Load data from PostgreSQL
159    fn load_postgres(
160        &self,
161        host: &str,
162        port: u16,
163        database: &str,
164        user: &str,
165        password: &str,
166        query: &str,
167    ) -> Result<DataFrame> {
168        let conn_str = format!(
169            "host={} port={} dbname={} user={} password={}",
170            host, port, database, user, password
171        );
172
173        let mut client = postgres::Client::connect(&conn_str, postgres::NoTls)
174            .map_err(|e| SigcError::Runtime(format!("Failed to connect to Postgres: {}", e)))?;
175
176        let rows = client.query(query, &[])
177            .map_err(|e| SigcError::Runtime(format!("Query failed: {}", e)))?;
178
179        if rows.is_empty() {
180            return Err(SigcError::Runtime("Query returned no rows".into()));
181        }
182
183        // Get column info from first row
184        let columns = rows[0].columns();
185        let mut series_data: Vec<(String, Vec<f64>)> = Vec::new();
186        let mut string_data: Vec<(String, Vec<String>)> = Vec::new();
187
188        // Initialize columns
189        for col in columns {
190            let name = col.name().to_string();
191            let type_name = col.type_().name();
192
193            match type_name {
194                "float4" | "float8" | "numeric" | "int2" | "int4" | "int8" => {
195                    series_data.push((name, Vec::with_capacity(rows.len())));
196                }
197                "text" | "varchar" | "date" | "timestamp" | "timestamptz" => {
198                    string_data.push((name, Vec::with_capacity(rows.len())));
199                }
200                _ => {
201                    // Try as string
202                    string_data.push((name, Vec::with_capacity(rows.len())));
203                }
204            }
205        }
206
207        // Extract data from rows
208        for row in &rows {
209            let mut float_idx = 0;
210            let mut string_idx = 0;
211
212            for (i, col) in columns.iter().enumerate() {
213                let type_name = col.type_().name();
214
215                match type_name {
216                    "float4" => {
217                        let val: Option<f32> = row.get(i);
218                        series_data[float_idx].1.push(val.map(|v| v as f64).unwrap_or(f64::NAN));
219                        float_idx += 1;
220                    }
221                    "float8" | "numeric" => {
222                        let val: Option<f64> = row.get(i);
223                        series_data[float_idx].1.push(val.unwrap_or(f64::NAN));
224                        float_idx += 1;
225                    }
226                    "int2" => {
227                        let val: Option<i16> = row.get(i);
228                        series_data[float_idx].1.push(val.unwrap_or(0) as f64);
229                        float_idx += 1;
230                    }
231                    "int4" => {
232                        let val: Option<i32> = row.get(i);
233                        series_data[float_idx].1.push(val.unwrap_or(0) as f64);
234                        float_idx += 1;
235                    }
236                    "int8" => {
237                        let val: Option<i64> = row.get(i);
238                        series_data[float_idx].1.push(val.unwrap_or(0) as f64);
239                        float_idx += 1;
240                    }
241                    _ => {
242                        // Handle dates, timestamps, and strings as strings
243                        let val: Option<String> = row.try_get(i).ok().flatten();
244                        string_data[string_idx].1.push(val.unwrap_or_default());
245                        string_idx += 1;
246                    }
247                }
248            }
249        }
250
251        // Build DataFrame
252        let mut df_columns: Vec<Column> = Vec::new();
253
254        for (name, values) in series_data {
255            df_columns.push(Column::new(name.into(), values));
256        }
257
258        for (name, values) in string_data {
259            df_columns.push(Column::new(name.into(), values));
260        }
261
262        DataFrame::new(df_columns)
263            .map_err(|e| SigcError::Runtime(format!("Failed to create DataFrame: {}", e)))
264    }
265
266    /// Execute a query that returns a count
267    pub fn query_count(&self, query: &str) -> Result<i64> {
268        match &self.config {
269            ConnectorConfig::Postgres { host, port, database, user, password } => {
270                let conn_str = format!(
271                    "host={} port={} dbname={} user={} password={}",
272                    host, port, database, user, password
273                );
274
275                let mut client = postgres::Client::connect(&conn_str, postgres::NoTls)
276                    .map_err(|e| SigcError::Runtime(format!("Connection failed: {}", e)))?;
277
278                let row = client.query_one(query, &[])
279                    .map_err(|e| SigcError::Runtime(format!("Query failed: {}", e)))?;
280
281                let value: i64 = row.get(0);
282                Ok(value)
283            }
284            _ => Err(SigcError::Runtime("Not a Postgres connector".into())),
285        }
286    }
287
288    /// List available tables
289    pub fn list_tables(&self) -> Result<Vec<String>> {
290        let query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'";
291        let df = self.load(query)?;
292
293        let col = df.column("table_name")
294            .map_err(|e| SigcError::Runtime(format!("Column error: {}", e)))?;
295
296        let tables: Vec<String> = col.str()
297            .map_err(|e| SigcError::Runtime(format!("Cast error: {}", e)))?
298            .into_iter()
299            .filter_map(|s| s.map(|s| s.to_string()))
300            .collect();
301
302        Ok(tables)
303    }
304}
305
306/// Cloud storage connector
307pub struct CloudConnector {
308    config: ConnectorConfig,
309    name: String,
310}
311
312impl CloudConnector {
313    /// Create an S3 connector
314    pub fn s3(bucket: &str, region: &str) -> Self {
315        CloudConnector {
316            config: ConnectorConfig::S3 {
317                bucket: bucket.to_string(),
318                region: region.to_string(),
319                access_key: None,
320                secret_key: None,
321            },
322            name: "s3".to_string(),
323        }
324    }
325
326    /// Create an S3 connector with credentials
327    pub fn s3_with_credentials(bucket: &str, region: &str, access_key: &str, secret_key: &str) -> Self {
328        CloudConnector {
329            config: ConnectorConfig::S3 {
330                bucket: bucket.to_string(),
331                region: region.to_string(),
332                access_key: Some(access_key.to_string()),
333                secret_key: Some(secret_key.to_string()),
334            },
335            name: "s3".to_string(),
336        }
337    }
338
339    /// Create a GCS connector
340    pub fn gcs(bucket: &str, project: &str) -> Self {
341        CloudConnector {
342            config: ConnectorConfig::Gcs {
343                bucket: bucket.to_string(),
344                project: project.to_string(),
345                credentials_path: None,
346            },
347            name: "gcs".to_string(),
348        }
349    }
350
351    /// Create an Azure connector
352    pub fn azure(container: &str, account: &str) -> Self {
353        CloudConnector {
354            config: ConnectorConfig::Azure {
355                container: container.to_string(),
356                account: account.to_string(),
357                access_key: None,
358            },
359            name: "azure".to_string(),
360        }
361    }
362
363    /// Get the full URI for a path
364    fn get_uri(&self, path: &str) -> String {
365        match &self.config {
366            ConnectorConfig::S3 { bucket, .. } => {
367                format!("s3://{}/{}", bucket, path)
368            }
369            ConnectorConfig::Gcs { bucket, .. } => {
370                format!("gs://{}/{}", bucket, path)
371            }
372            ConnectorConfig::Azure { container, account, .. } => {
373                format!("az://{}.blob.core.windows.net/{}/{}", account, container, path)
374            }
375            _ => path.to_string(),
376        }
377    }
378}
379
380impl Connector for CloudConnector {
381    fn load(&self, path: &str) -> Result<DataFrame> {
382        let uri = self.get_uri(path);
383
384        // Determine format from extension
385        let is_parquet = path.ends_with(".parquet") || path.ends_with(".pq");
386        let is_csv = path.ends_with(".csv") || path.ends_with(".csv.gz");
387
388        if is_parquet {
389            // Use object_store for cloud parquet
390            // This is a simplified implementation
391            LazyFrame::scan_parquet(&uri, ScanArgsParquet::default())
392                .map_err(|e| SigcError::Runtime(format!("Failed to scan parquet: {}", e)))?
393                .collect()
394                .map_err(|e| SigcError::Runtime(format!("Failed to collect: {}", e)))
395        } else if is_csv {
396            LazyCsvReader::new(&uri)
397                .finish()
398                .map_err(|e| SigcError::Runtime(format!("Failed to read CSV: {}", e)))?
399                .collect()
400                .map_err(|e| SigcError::Runtime(format!("Failed to collect: {}", e)))
401        } else {
402            Err(SigcError::Runtime(format!("Unknown file format: {}", path)))
403        }
404    }
405
406    fn is_available(&self) -> bool {
407        true // Would check actual connectivity
408    }
409
410    fn name(&self) -> &str {
411        &self.name
412    }
413}
414
415/// Connector registry for managing multiple data sources
416pub struct ConnectorRegistry {
417    connectors: HashMap<String, Box<dyn Connector>>,
418}
419
420impl ConnectorRegistry {
421    /// Create a new empty registry
422    pub fn new() -> Self {
423        ConnectorRegistry {
424            connectors: HashMap::new(),
425        }
426    }
427
428    /// Register a connector
429    pub fn register(&mut self, name: &str, connector: Box<dyn Connector>) {
430        self.connectors.insert(name.to_string(), connector);
431    }
432
433    /// Get a connector by name
434    pub fn get(&self, name: &str) -> Option<&dyn Connector> {
435        self.connectors.get(name).map(|c| c.as_ref())
436    }
437
438    /// Load data using a connector
439    pub fn load(&self, connector_name: &str, path: &str) -> Result<DataFrame> {
440        let connector = self.connectors.get(connector_name)
441            .ok_or_else(|| SigcError::Runtime(format!("Connector not found: {}", connector_name)))?;
442
443        connector.load(path)
444    }
445
446    /// List all registered connectors
447    pub fn list(&self) -> Vec<String> {
448        self.connectors.keys().cloned().collect()
449    }
450
451    /// Check if a connector exists
452    pub fn has(&self, name: &str) -> bool {
453        self.connectors.contains_key(name)
454    }
455}
456
457impl Default for ConnectorRegistry {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
/// Stateless namespace for constructors that build connectors from
/// environment variables.
pub struct ConnectorEnv;
465
466impl ConnectorEnv {
467    /// Create S3 connector from environment variables
468    pub fn s3_from_env(bucket: &str) -> CloudConnector {
469        let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string());
470        let access_key = std::env::var("AWS_ACCESS_KEY_ID").ok();
471        let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok();
472
473        if let (Some(ak), Some(sk)) = (access_key, secret_key) {
474            CloudConnector::s3_with_credentials(bucket, &region, &ak, &sk)
475        } else {
476            CloudConnector::s3(bucket, &region)
477        }
478    }
479
480    /// Create Postgres connector from environment
481    pub fn postgres_from_env() -> Option<SqlConnector> {
482        let host = std::env::var("PGHOST").ok()?;
483        let port: u16 = std::env::var("PGPORT").ok()?.parse().ok()?;
484        let database = std::env::var("PGDATABASE").ok()?;
485        let user = std::env::var("PGUSER").ok()?;
486        let password = std::env::var("PGPASSWORD").ok()?;
487
488        Some(SqlConnector::postgres(&host, port, &database, &user, &password))
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_connection_string() {
        let url =
            SqlConnector::postgres("localhost", 5432, "testdb", "user", "pass").connection_string();
        assert!(url.contains("postgresql://"));
        assert!(url.contains("localhost:5432"));
    }

    #[test]
    fn test_s3_uri() {
        let s3 = CloudConnector::s3("my-bucket", "us-east-1");
        assert_eq!(
            s3.get_uri("data/prices.parquet"),
            "s3://my-bucket/data/prices.parquet"
        );
    }

    #[test]
    fn test_gcs_uri() {
        let gcs = CloudConnector::gcs("my-bucket", "my-project");
        assert_eq!(
            gcs.get_uri("data/prices.parquet"),
            "gs://my-bucket/data/prices.parquet"
        );
    }

    #[test]
    fn test_registry() {
        let s3 = CloudConnector::s3("bucket", "region");
        let mut reg = ConnectorRegistry::new();
        reg.register("s3_data", Box::new(s3));

        assert!(reg.has("s3_data"));
        assert!(!reg.has("nonexistent"));
        assert_eq!(reg.list().len(), 1);
    }
}