1use datafusion::arrow::array::RecordBatch;
2use datafusion::arrow::array::*;
3use datafusion::arrow::datatypes::{DataType, Field, Schema};
4use postgres::{Client, NoTls};
5use std::sync::Arc;
6
7use crate::engine::DataEngine;
8
9fn pg_type_to_arrow(pg_type: &postgres::types::Type) -> DataType {
11 match *pg_type {
12 postgres::types::Type::BOOL => DataType::Boolean,
13 postgres::types::Type::INT2 => DataType::Int16,
14 postgres::types::Type::INT4 => DataType::Int32,
15 postgres::types::Type::INT8 => DataType::Int64,
16 postgres::types::Type::FLOAT4 => DataType::Float32,
17 postgres::types::Type::FLOAT8 => DataType::Float64,
18 postgres::types::Type::TEXT
19 | postgres::types::Type::VARCHAR
20 | postgres::types::Type::BPCHAR => DataType::Utf8,
21 _ => DataType::Utf8, }
23}
24
25impl DataEngine {
26 pub fn read_postgres(
28 &self,
29 conn_str: &str,
30 table_name: &str,
31 ) -> Result<datafusion::prelude::DataFrame, String> {
32 let mut client = Client::connect(conn_str, NoTls)
33 .map_err(|e| format!("PostgreSQL connection error: {e}"))?;
34
35 let query = format!("SELECT * FROM {table_name}");
36 let rows = client
37 .query(&query, &[])
38 .map_err(|e| format!("PostgreSQL query error: {e}"))?;
39
40 if rows.is_empty() {
41 return Err(format!("Table '{table_name}' is empty or does not exist"));
42 }
43
44 let columns = rows[0].columns();
45 let fields: Vec<Field> = columns
46 .iter()
47 .map(|col| Field::new(col.name(), pg_type_to_arrow(col.type_()), true))
48 .collect();
49 let schema = Arc::new(Schema::new(fields));
50
51 let mut arrays: Vec<Arc<dyn Array>> = Vec::new();
52 for (col_idx, col) in columns.iter().enumerate() {
53 let arrow_type = pg_type_to_arrow(col.type_());
54 let array: Arc<dyn Array> = match arrow_type {
55 DataType::Boolean => {
56 let values: Vec<Option<bool>> = rows.iter().map(|r| r.get(col_idx)).collect();
57 Arc::new(BooleanArray::from(values))
58 }
59 DataType::Int16 => {
60 let values: Vec<Option<i16>> = rows.iter().map(|r| r.get(col_idx)).collect();
61 Arc::new(Int16Array::from(values))
62 }
63 DataType::Int32 => {
64 let values: Vec<Option<i32>> = rows.iter().map(|r| r.get(col_idx)).collect();
65 Arc::new(Int32Array::from(values))
66 }
67 DataType::Int64 => {
68 let values: Vec<Option<i64>> = rows.iter().map(|r| r.get(col_idx)).collect();
69 Arc::new(Int64Array::from(values))
70 }
71 DataType::Float32 => {
72 let values: Vec<Option<f32>> = rows.iter().map(|r| r.get(col_idx)).collect();
73 Arc::new(Float32Array::from(values))
74 }
75 DataType::Float64 => {
76 let values: Vec<Option<f64>> = rows.iter().map(|r| r.get(col_idx)).collect();
77 Arc::new(Float64Array::from(values))
78 }
79 _ => {
80 let values: Vec<Option<String>> = rows
81 .iter()
82 .map(|r| r.try_get::<_, String>(col_idx).ok())
83 .collect();
84 Arc::new(StringArray::from(values))
85 }
86 };
87 arrays.push(array);
88 }
89
90 let batch = RecordBatch::try_new(schema, arrays)
91 .map_err(|e| format!("Arrow RecordBatch creation error: {e}"))?;
92
93 self.register_batch(table_name, batch)?;
94
95 self.rt
96 .block_on(self.ctx.table(table_name))
97 .map_err(|e| format!("Table reference error: {e}"))
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Integration test that loads the `users` table from a live PostgreSQL server.
    ///
    /// Run it with `cargo test -- --ignored`. The connection string defaults to a local
    /// `testdb` database and can be overridden with the `PG_TEST_CONN` environment variable.
    #[test]
    #[ignore = "requires a running PostgreSQL server with a `users` table"]
    fn test_read_postgres() {
        let conn_str = std::env::var("PG_TEST_CONN").unwrap_or_else(|_| {
            "host=localhost user=postgres password=postgres dbname=testdb".to_string()
        });
        let engine = DataEngine::new();
        let df = engine
            .read_postgres(&conn_str, "users")
            .expect("read_postgres should load the `users` table");
        let batches = engine
            .collect(df)
            .expect("collecting the loaded DataFrame should succeed");
        assert!(!batches.is_empty());
    }
}