1use polars::prelude::*;
6use sig_types::{Result, SigcError};
7use std::collections::HashMap;
8
/// Connection settings for every supported data source.
///
/// `Debug` is implemented by hand instead of derived so that passwords and
/// access keys never end up verbatim in logs or panic messages: those fields
/// are rendered as `"<redacted>"` (or `None` when an optional key is absent).
#[derive(Clone)]
pub enum ConnectorConfig {
    /// PostgreSQL server.
    Postgres {
        host: String,
        port: u16,
        database: String,
        user: String,
        password: String,
    },
    /// Snowflake account / warehouse.
    Snowflake {
        account: String,
        warehouse: String,
        database: String,
        schema: String,
        user: String,
        password: String,
    },
    /// Amazon S3 bucket; credentials are optional.
    S3 {
        bucket: String,
        region: String,
        access_key: Option<String>,
        secret_key: Option<String>,
    },
    /// Google Cloud Storage bucket.
    Gcs {
        bucket: String,
        project: String,
        credentials_path: Option<String>,
    },
    /// Azure Blob Storage container.
    Azure {
        container: String,
        account: String,
        access_key: Option<String>,
    },
}

impl std::fmt::Debug for ConnectorConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of any secret value.
        const REDACTED: &str = "<redacted>";

        match self {
            ConnectorConfig::Postgres { host, port, database, user, .. } => f
                .debug_struct("Postgres")
                .field("host", host)
                .field("port", port)
                .field("database", database)
                .field("user", user)
                .field("password", &REDACTED)
                .finish(),
            ConnectorConfig::Snowflake { account, warehouse, database, schema, user, .. } => f
                .debug_struct("Snowflake")
                .field("account", account)
                .field("warehouse", warehouse)
                .field("database", database)
                .field("schema", schema)
                .field("user", user)
                .field("password", &REDACTED)
                .finish(),
            ConnectorConfig::S3 { bucket, region, access_key, secret_key } => f
                .debug_struct("S3")
                .field("bucket", bucket)
                .field("region", region)
                // Keep the Some/None distinction (useful when debugging) but hide the value.
                .field("access_key", &access_key.as_ref().map(|_| REDACTED))
                .field("secret_key", &secret_key.as_ref().map(|_| REDACTED))
                .finish(),
            ConnectorConfig::Gcs { bucket, project, credentials_path } => f
                .debug_struct("Gcs")
                .field("bucket", bucket)
                .field("project", project)
                .field("credentials_path", credentials_path)
                .finish(),
            ConnectorConfig::Azure { container, account, access_key } => f
                .debug_struct("Azure")
                .field("container", container)
                .field("account", account)
                .field("access_key", &access_key.as_ref().map(|_| REDACTED))
                .finish(),
        }
    }
}
49
50pub trait Connector: Send + Sync {
52 fn load(&self, path: &str) -> Result<DataFrame>;
54
55 fn is_available(&self) -> bool;
57
58 fn name(&self) -> &str;
60}
61
62pub struct SqlConnector {
64 config: ConnectorConfig,
65 name: String,
66}
67
68impl SqlConnector {
69 pub fn postgres(host: &str, port: u16, database: &str, user: &str, password: &str) -> Self {
71 SqlConnector {
72 config: ConnectorConfig::Postgres {
73 host: host.to_string(),
74 port,
75 database: database.to_string(),
76 user: user.to_string(),
77 password: password.to_string(),
78 },
79 name: "postgres".to_string(),
80 }
81 }
82
83 pub fn snowflake(
85 account: &str,
86 warehouse: &str,
87 database: &str,
88 schema: &str,
89 user: &str,
90 password: &str,
91 ) -> Self {
92 SqlConnector {
93 config: ConnectorConfig::Snowflake {
94 account: account.to_string(),
95 warehouse: warehouse.to_string(),
96 database: database.to_string(),
97 schema: schema.to_string(),
98 user: user.to_string(),
99 password: password.to_string(),
100 },
101 name: "snowflake".to_string(),
102 }
103 }
104
105 #[allow(dead_code)]
107 fn connection_string(&self) -> String {
108 match &self.config {
109 ConnectorConfig::Postgres { host, port, database, user, password } => {
110 format!("postgresql://{}:{}@{}:{}/{}", user, password, host, port, database)
111 }
112 ConnectorConfig::Snowflake { account, warehouse, database, schema, user, password } => {
113 format!(
114 "snowflake://{}:{}@{}/{}/{}?warehouse={}",
115 user, password, account, database, schema, warehouse
116 )
117 }
118 _ => String::new(),
119 }
120 }
121}
122
123impl Connector for SqlConnector {
124 fn load(&self, query: &str) -> Result<DataFrame> {
125 match &self.config {
126 ConnectorConfig::Postgres { host, port, database, user, password } => {
127 self.load_postgres(host, *port, database, user, password, query)
128 }
129 ConnectorConfig::Snowflake { .. } => {
130 Err(SigcError::Runtime(
132 "Snowflake connector requires snowflake-connector. Use ODBC or REST API.".into()
133 ))
134 }
135 _ => Err(SigcError::Runtime("Invalid config for SQL connector".into())),
136 }
137 }
138
139 fn is_available(&self) -> bool {
140 match &self.config {
141 ConnectorConfig::Postgres { host, port, database, user, password } => {
142 let conn_str = format!(
143 "host={} port={} dbname={} user={} password={}",
144 host, port, database, user, password
145 );
146 postgres::Client::connect(&conn_str, postgres::NoTls).is_ok()
147 }
148 _ => false,
149 }
150 }
151
152 fn name(&self) -> &str {
153 &self.name
154 }
155}
156
157impl SqlConnector {
158 fn load_postgres(
160 &self,
161 host: &str,
162 port: u16,
163 database: &str,
164 user: &str,
165 password: &str,
166 query: &str,
167 ) -> Result<DataFrame> {
168 let conn_str = format!(
169 "host={} port={} dbname={} user={} password={}",
170 host, port, database, user, password
171 );
172
173 let mut client = postgres::Client::connect(&conn_str, postgres::NoTls)
174 .map_err(|e| SigcError::Runtime(format!("Failed to connect to Postgres: {}", e)))?;
175
176 let rows = client.query(query, &[])
177 .map_err(|e| SigcError::Runtime(format!("Query failed: {}", e)))?;
178
179 if rows.is_empty() {
180 return Err(SigcError::Runtime("Query returned no rows".into()));
181 }
182
183 let columns = rows[0].columns();
185 let mut series_data: Vec<(String, Vec<f64>)> = Vec::new();
186 let mut string_data: Vec<(String, Vec<String>)> = Vec::new();
187
188 for col in columns {
190 let name = col.name().to_string();
191 let type_name = col.type_().name();
192
193 match type_name {
194 "float4" | "float8" | "numeric" | "int2" | "int4" | "int8" => {
195 series_data.push((name, Vec::with_capacity(rows.len())));
196 }
197 "text" | "varchar" | "date" | "timestamp" | "timestamptz" => {
198 string_data.push((name, Vec::with_capacity(rows.len())));
199 }
200 _ => {
201 string_data.push((name, Vec::with_capacity(rows.len())));
203 }
204 }
205 }
206
207 for row in &rows {
209 let mut float_idx = 0;
210 let mut string_idx = 0;
211
212 for (i, col) in columns.iter().enumerate() {
213 let type_name = col.type_().name();
214
215 match type_name {
216 "float4" => {
217 let val: Option<f32> = row.get(i);
218 series_data[float_idx].1.push(val.map(|v| v as f64).unwrap_or(f64::NAN));
219 float_idx += 1;
220 }
221 "float8" | "numeric" => {
222 let val: Option<f64> = row.get(i);
223 series_data[float_idx].1.push(val.unwrap_or(f64::NAN));
224 float_idx += 1;
225 }
226 "int2" => {
227 let val: Option<i16> = row.get(i);
228 series_data[float_idx].1.push(val.unwrap_or(0) as f64);
229 float_idx += 1;
230 }
231 "int4" => {
232 let val: Option<i32> = row.get(i);
233 series_data[float_idx].1.push(val.unwrap_or(0) as f64);
234 float_idx += 1;
235 }
236 "int8" => {
237 let val: Option<i64> = row.get(i);
238 series_data[float_idx].1.push(val.unwrap_or(0) as f64);
239 float_idx += 1;
240 }
241 _ => {
242 let val: Option<String> = row.try_get(i).ok().flatten();
244 string_data[string_idx].1.push(val.unwrap_or_default());
245 string_idx += 1;
246 }
247 }
248 }
249 }
250
251 let mut df_columns: Vec<Column> = Vec::new();
253
254 for (name, values) in series_data {
255 df_columns.push(Column::new(name.into(), values));
256 }
257
258 for (name, values) in string_data {
259 df_columns.push(Column::new(name.into(), values));
260 }
261
262 DataFrame::new(df_columns)
263 .map_err(|e| SigcError::Runtime(format!("Failed to create DataFrame: {}", e)))
264 }
265
266 pub fn query_count(&self, query: &str) -> Result<i64> {
268 match &self.config {
269 ConnectorConfig::Postgres { host, port, database, user, password } => {
270 let conn_str = format!(
271 "host={} port={} dbname={} user={} password={}",
272 host, port, database, user, password
273 );
274
275 let mut client = postgres::Client::connect(&conn_str, postgres::NoTls)
276 .map_err(|e| SigcError::Runtime(format!("Connection failed: {}", e)))?;
277
278 let row = client.query_one(query, &[])
279 .map_err(|e| SigcError::Runtime(format!("Query failed: {}", e)))?;
280
281 let value: i64 = row.get(0);
282 Ok(value)
283 }
284 _ => Err(SigcError::Runtime("Not a Postgres connector".into())),
285 }
286 }
287
288 pub fn list_tables(&self) -> Result<Vec<String>> {
290 let query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'";
291 let df = self.load(query)?;
292
293 let col = df.column("table_name")
294 .map_err(|e| SigcError::Runtime(format!("Column error: {}", e)))?;
295
296 let tables: Vec<String> = col.str()
297 .map_err(|e| SigcError::Runtime(format!("Cast error: {}", e)))?
298 .into_iter()
299 .filter_map(|s| s.map(|s| s.to_string()))
300 .collect();
301
302 Ok(tables)
303 }
304}
305
306pub struct CloudConnector {
308 config: ConnectorConfig,
309 name: String,
310}
311
312impl CloudConnector {
313 pub fn s3(bucket: &str, region: &str) -> Self {
315 CloudConnector {
316 config: ConnectorConfig::S3 {
317 bucket: bucket.to_string(),
318 region: region.to_string(),
319 access_key: None,
320 secret_key: None,
321 },
322 name: "s3".to_string(),
323 }
324 }
325
326 pub fn s3_with_credentials(bucket: &str, region: &str, access_key: &str, secret_key: &str) -> Self {
328 CloudConnector {
329 config: ConnectorConfig::S3 {
330 bucket: bucket.to_string(),
331 region: region.to_string(),
332 access_key: Some(access_key.to_string()),
333 secret_key: Some(secret_key.to_string()),
334 },
335 name: "s3".to_string(),
336 }
337 }
338
339 pub fn gcs(bucket: &str, project: &str) -> Self {
341 CloudConnector {
342 config: ConnectorConfig::Gcs {
343 bucket: bucket.to_string(),
344 project: project.to_string(),
345 credentials_path: None,
346 },
347 name: "gcs".to_string(),
348 }
349 }
350
351 pub fn azure(container: &str, account: &str) -> Self {
353 CloudConnector {
354 config: ConnectorConfig::Azure {
355 container: container.to_string(),
356 account: account.to_string(),
357 access_key: None,
358 },
359 name: "azure".to_string(),
360 }
361 }
362
363 fn get_uri(&self, path: &str) -> String {
365 match &self.config {
366 ConnectorConfig::S3 { bucket, .. } => {
367 format!("s3://{}/{}", bucket, path)
368 }
369 ConnectorConfig::Gcs { bucket, .. } => {
370 format!("gs://{}/{}", bucket, path)
371 }
372 ConnectorConfig::Azure { container, account, .. } => {
373 format!("az://{}.blob.core.windows.net/{}/{}", account, container, path)
374 }
375 _ => path.to_string(),
376 }
377 }
378}
379
380impl Connector for CloudConnector {
381 fn load(&self, path: &str) -> Result<DataFrame> {
382 let uri = self.get_uri(path);
383
384 let is_parquet = path.ends_with(".parquet") || path.ends_with(".pq");
386 let is_csv = path.ends_with(".csv") || path.ends_with(".csv.gz");
387
388 if is_parquet {
389 LazyFrame::scan_parquet(&uri, ScanArgsParquet::default())
392 .map_err(|e| SigcError::Runtime(format!("Failed to scan parquet: {}", e)))?
393 .collect()
394 .map_err(|e| SigcError::Runtime(format!("Failed to collect: {}", e)))
395 } else if is_csv {
396 LazyCsvReader::new(&uri)
397 .finish()
398 .map_err(|e| SigcError::Runtime(format!("Failed to read CSV: {}", e)))?
399 .collect()
400 .map_err(|e| SigcError::Runtime(format!("Failed to collect: {}", e)))
401 } else {
402 Err(SigcError::Runtime(format!("Unknown file format: {}", path)))
403 }
404 }
405
406 fn is_available(&self) -> bool {
407 true }
409
410 fn name(&self) -> &str {
411 &self.name
412 }
413}
414
415pub struct ConnectorRegistry {
417 connectors: HashMap<String, Box<dyn Connector>>,
418}
419
420impl ConnectorRegistry {
421 pub fn new() -> Self {
423 ConnectorRegistry {
424 connectors: HashMap::new(),
425 }
426 }
427
428 pub fn register(&mut self, name: &str, connector: Box<dyn Connector>) {
430 self.connectors.insert(name.to_string(), connector);
431 }
432
433 pub fn get(&self, name: &str) -> Option<&dyn Connector> {
435 self.connectors.get(name).map(|c| c.as_ref())
436 }
437
438 pub fn load(&self, connector_name: &str, path: &str) -> Result<DataFrame> {
440 let connector = self.connectors.get(connector_name)
441 .ok_or_else(|| SigcError::Runtime(format!("Connector not found: {}", connector_name)))?;
442
443 connector.load(path)
444 }
445
446 pub fn list(&self) -> Vec<String> {
448 self.connectors.keys().cloned().collect()
449 }
450
451 pub fn has(&self, name: &str) -> bool {
453 self.connectors.contains_key(name)
454 }
455}
456
457impl Default for ConnectorRegistry {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
/// Namespace for constructing connectors from environment variables.
pub struct ConnectorEnv;
465
466impl ConnectorEnv {
467 pub fn s3_from_env(bucket: &str) -> CloudConnector {
469 let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".to_string());
470 let access_key = std::env::var("AWS_ACCESS_KEY_ID").ok();
471 let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok();
472
473 if let (Some(ak), Some(sk)) = (access_key, secret_key) {
474 CloudConnector::s3_with_credentials(bucket, ®ion, &ak, &sk)
475 } else {
476 CloudConnector::s3(bucket, ®ion)
477 }
478 }
479
480 pub fn postgres_from_env() -> Option<SqlConnector> {
482 let host = std::env::var("PGHOST").ok()?;
483 let port: u16 = std::env::var("PGPORT").ok()?.parse().ok()?;
484 let database = std::env::var("PGDATABASE").ok()?;
485 let user = std::env::var("PGUSER").ok()?;
486 let password = std::env::var("PGPASSWORD").ok()?;
487
488 Some(SqlConnector::postgres(&host, port, &database, &user, &password))
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_connection_string() {
        let url = SqlConnector::postgres("localhost", 5432, "testdb", "user", "pass").connection_string();
        // Both the scheme and the host:port pair must appear in the rendered URL.
        for needle in ["postgresql://", "localhost:5432"] {
            assert!(url.contains(needle), "{:?} not found in {:?}", needle, url);
        }
    }

    #[test]
    fn test_s3_uri() {
        let s3 = CloudConnector::s3("my-bucket", "us-east-1");
        assert_eq!(s3.get_uri("data/prices.parquet"), "s3://my-bucket/data/prices.parquet");
    }

    #[test]
    fn test_gcs_uri() {
        let gcs = CloudConnector::gcs("my-bucket", "my-project");
        assert_eq!(gcs.get_uri("data/prices.parquet"), "gs://my-bucket/data/prices.parquet");
    }

    #[test]
    fn test_registry() {
        let mut registry = ConnectorRegistry::default();
        registry.register("s3_data", Box::new(CloudConnector::s3("bucket", "region")));

        assert!(registry.has("s3_data"));
        assert!(!registry.has("nonexistent"));
        assert_eq!(registry.list(), vec!["s3_data".to_string()]);
    }
}