1use crate::codec::{BackendMessage, PgCodec};
2use crate::message::Message;
3use crate::transaction::PostgresTransaction;
4use futures::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_util::codec::Framed;
7use tracing::{debug, info};
8use yykv_types::{ColumnInfo, DsError, DsResult, DsValue, EnumInfo, SchemaInspector, TableInfo};
9
10#[async_trait::async_trait]
11impl SchemaInspector for PostgresConnection {
12 async fn introspect(&self, schema: Option<&str>) -> DsResult<(Vec<TableInfo>, Vec<EnumInfo>)> {
13 let schema = schema.unwrap_or("public");
14
15 let enum_sql = format!(
17 "SELECT
18 t.typname as enum_name,
19 e.enumlabel as enum_variant
20 FROM pg_type t
21 JOIN pg_enum e ON t.oid = e.enumtypid
22 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
23 WHERE n.nspname = '{}'
24 ORDER BY t.typname, e.enumsortorder",
25 schema
26 );
27
28 let enum_rows = self.query(&enum_sql, &[]).await?;
29 let mut enums = Vec::new();
30 let mut current_enum: Option<EnumInfo> = None;
31
32 for row in enum_rows {
33 if let DsValue::List(fields) = row {
34 let name = match fields.first() {
35 Some(DsValue::Text(s)) => s.clone(),
36 _ => continue,
37 };
38 let variant = match fields.get(1) {
39 Some(DsValue::Text(s)) => s.clone(),
40 _ => continue,
41 };
42
43 if let Some(ref mut e) = current_enum {
44 if e.name == name {
45 e.variants.push(variant);
46 continue;
47 } else {
48 enums.push(current_enum.take().unwrap());
49 }
50 }
51 current_enum = Some(EnumInfo {
52 name,
53 variants: vec![variant],
54 });
55 }
56 }
57 if let Some(e) = current_enum {
58 enums.push(e);
59 }
60
61 let sql = format!(
63 "SELECT
64 c.relname as table_name,
65 pg_catalog.obj_description(c.oid, 'pg_class') as table_comment
66 FROM pg_catalog.pg_class c
67 JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
68 WHERE n.nspname = '{}' AND c.relkind = 'r' AND c.relname NOT LIKE 'pg_%' AND c.relname NOT LIKE 'sql_%'",
69 schema
70 );
71
72 let table_rows = self.query(&sql, &[]).await?;
73 let mut tables = Vec::new();
74
75 for row in table_rows {
76 if let DsValue::List(fields) = row {
77 let table_name = match fields.first() {
78 Some(DsValue::Text(s)) => s.clone(),
79 _ => continue,
80 };
81 let description = match fields.get(1) {
82 Some(DsValue::Text(s)) => Some(s.clone()),
83 _ => None,
84 };
85
86 let mut columns = Vec::new();
87 let col_sql = format!(
88 "SELECT
89 column_name,
90 data_type,
91 is_nullable,
92 column_default,
93 (SELECT pg_catalog.col_description(c.oid, a.attnum)
94 FROM pg_catalog.pg_class c
95 JOIN pg_catalog.pg_attribute a ON c.oid = a.attrelid
96 WHERE c.relname = '{}' AND a.attname = column_name) as column_comment
97 FROM information_schema.columns
98 WHERE table_schema = '{}' AND table_name = '{}'
99 ORDER BY ordinal_position",
100 table_name, schema, table_name
101 );
102
103 let col_rows = self.query(&col_sql, &[]).await?;
104 for col_row in col_rows {
105 if let DsValue::List(cfields) = col_row {
106 let name = match cfields.first() {
107 Some(DsValue::Text(s)) => s.clone(),
108 _ => continue,
109 };
110 let data_type = match cfields.get(1) {
111 Some(DsValue::Text(s)) => s.clone(),
112 _ => continue,
113 };
114 let is_nullable = match cfields.get(2) {
115 Some(DsValue::Text(s)) => s == "YES",
116 _ => false,
117 };
118 let default = match cfields.get(3) {
119 Some(DsValue::Text(s)) => Some(s.clone()),
120 _ => None,
121 };
122 let col_description = match cfields.get(4) {
123 Some(DsValue::Text(s)) => Some(s.clone()),
124 _ => None,
125 };
126
127 columns.push(ColumnInfo {
128 name,
129 data_type,
130 is_nullable,
131 is_primary_key: false, is_enum: false, foreign_key: None,
134 default,
135 description: col_description,
136 });
137 }
138 }
139
140 tables.push(TableInfo {
141 name: table_name,
142 columns,
143 description,
144 });
145 }
146 }
147
148 Ok((tables, enums))
149 }
150}
151
152type Result<T> = std::result::Result<T, DsError>;
153
/// Describes a PostgreSQL server by its connection URL.
///
/// Only the URL is stored. `execute` and `query` each call `connect`, which opens a new
/// TCP connection per call.
pub struct PostgresConnection {
    /// Connection URL; `parse_url` accepts the `postgres://` and `postgresql://` schemes.
    pub url: String,
}
157
/// Connection settings extracted from a `PostgresConnection` URL by `parse_url`.
struct ConnectionConfig {
    /// Host part of the URL; joined with `port` to form the TCP address.
    host: String,
    /// TCP port; `parse_url` falls back to 5432 when the URL has none.
    port: u16,
    /// Role name sent in the startup packet and mixed into the MD5 password hash.
    user: String,
    /// Password from the URL, if any; only needed when the server asks for one.
    password: Option<String>,
    /// Database name sent in the startup packet.
    database: String,
}
165
166impl PostgresConnection {
167 pub fn new(url: String) -> Self {
168 Self { url }
169 }
170
171 fn parse_url(&self) -> Result<ConnectionConfig> {
172 let url = self
173 .url
174 .strip_prefix("postgres://")
175 .or_else(|| self.url.strip_prefix("postgresql://"))
176 .ok_or_else(|| DsError::protocol("Invalid postgres URL scheme"))?;
177
178 let (auth_host, database) = match url.find('/') {
179 Some(i) => (&url[..i], url[i + 1..].to_string()),
180 None => (url, "postgres".to_string()),
181 };
182
183 let (auth, host_port) = match auth_host.find('@') {
184 Some(i) => (&auth_host[..i], &auth_host[i + 1..]),
185 None => ("postgres", auth_host),
186 };
187
188 let (user, password) = match auth.find(':') {
189 Some(i) => (auth[..i].to_string(), Some(auth[i + 1..].to_string())),
190 None => (auth.to_string(), None),
191 };
192
193 let (host, port) = match host_port.find(':') {
194 Some(i) => (
195 host_port[..i].to_string(),
196 host_port[i + 1..]
197 .parse()
198 .map_err(|_| DsError::protocol("Invalid port"))?,
199 ),
200 None => (host_port.to_string(), 5432),
201 };
202
203 Ok(ConnectionConfig {
204 host,
205 port,
206 user,
207 password,
208 database,
209 })
210 }
211
212 async fn connect(&self) -> Result<Framed<TcpStream, PgCodec>> {
213 let config = self.parse_url()?;
214 let addr = format!("{}:{}", config.host, config.port);
215 let stream = TcpStream::connect(addr)
216 .await
217 .map_err(|e| DsError::protocol(format!("Failed to connect to postgres: {}", e)))?;
218
219 let mut framed = Framed::new(stream, PgCodec::new());
220
221 let params = vec![
223 ("user".to_string(), config.user.clone()),
224 ("database".to_string(), config.database.clone()),
225 ("client_encoding".to_string(), "UTF8".to_string()),
226 ];
227
228 framed.send(Message::Startup { params }).await?;
229
230 loop {
232 let msg = framed
233 .next()
234 .await
235 .ok_or_else(|| DsError::protocol("Connection closed during startup"))??;
236
237 match msg {
238 BackendMessage::AuthenticationOk => {
239 debug!("Postgres authentication OK");
240 }
241 BackendMessage::AuthenticationCleartextPassword => {
242 let pass = config.password.as_ref().ok_or_else(|| {
243 DsError::protocol("Password required by server but not provided")
244 })?;
245 framed.send(Message::Password(pass.clone())).await?;
246 }
247 BackendMessage::AuthenticationMD5Password { salt } => {
248 let user = config.user.clone();
249 let pass = config.password.as_ref().ok_or_else(|| {
250 DsError::protocol("Password required by server but not provided")
251 })?;
252
253 let hash1 = md5::compute(format!("{}{}", pass, user));
255 let hash1_hex = hex::encode(hash1.0);
256
257 let mut hash2_input = Vec::new();
258 hash2_input.extend_from_slice(hash1_hex.as_bytes());
259 hash2_input.extend_from_slice(&salt);
260
261 let hash2 = md5::compute(hash2_input);
262 let response = format!("md5{}", hex::encode(hash2.0));
263
264 framed.send(Message::Password(response)).await?;
265 }
266 BackendMessage::ParameterStatus { name, value } => {
267 debug!("Postgres parameter status: {} = {}", name, value);
268 }
269 BackendMessage::BackendKeyData { .. } => {}
270 BackendMessage::ReadyForQuery { .. } => {
271 break;
272 }
273 BackendMessage::ErrorResponse { fields } => {
274 let msg = fields
275 .iter()
276 .find(|(t, _)| *t == b'M')
277 .map(|(_, m)| m.clone())
278 .unwrap_or_default();
279 return Err(DsError::protocol(format!(
280 "Postgres startup error: {}",
281 msg
282 )));
283 }
284 _ => {
285 debug!("Received unexpected message during startup: {:?}", msg);
286 }
287 }
288 }
289
290 Ok(framed)
291 }
292
293 pub async fn execute(&self, sql: &str, params: &[DsValue]) -> Result<u64> {
294 info!("Postgres executing: {}", sql);
295 let mut framed = self.connect().await?;
296
297 if params.is_empty() {
298 framed.send(Message::Query(sql.to_string())).await?;
300 } else {
301 framed
303 .send(Message::Parse {
304 name: "".to_string(),
305 query: sql.to_string(),
306 param_types: vec![0; params.len()], })
308 .await?;
309
310 framed
311 .send(Message::Bind {
312 portal: "".to_string(),
313 statement: "".to_string(),
314 params: params.to_vec(),
315 })
316 .await?;
317
318 framed
319 .send(Message::Execute {
320 portal: "".to_string(),
321 max_rows: 0,
322 })
323 .await?;
324
325 framed.send(Message::Sync).await?;
326 }
327
328 let mut affected_rows = 0;
329 loop {
330 let msg = framed
331 .next()
332 .await
333 .ok_or_else(|| DsError::protocol("Connection closed during execute"))??;
334
335 match msg {
336 BackendMessage::CommandComplete { tag } => {
337 if let Some(s) = tag.split_whitespace().last() {
338 affected_rows = s.parse().unwrap_or(0);
339 }
340 }
341 BackendMessage::ReadyForQuery { .. } => break,
342 BackendMessage::ErrorResponse { fields } => {
343 let msg = fields
344 .iter()
345 .find(|(t, _)| *t == b'M')
346 .map(|(_, m)| m.clone())
347 .unwrap_or_default();
348 return Err(DsError::protocol(format!(
349 "Postgres execute error: {}",
350 msg
351 )));
352 }
353 _ => {}
354 }
355 }
356
357 Ok(affected_rows)
358 }
359
360 pub async fn query(&self, sql: &str, params: &[DsValue]) -> Result<Vec<DsValue>> {
361 info!("Postgres querying: {}", sql);
362
363 let mut framed = self.connect().await?;
364
365 if params.is_empty() {
366 framed.send(Message::Query(sql.to_string())).await?;
368 } else {
369 framed
371 .send(Message::Parse {
372 name: "".to_string(),
373 query: sql.to_string(),
374 param_types: vec![0; params.len()], })
376 .await?;
377
378 framed
379 .send(Message::Bind {
380 portal: "".to_string(),
381 statement: "".to_string(),
382 params: params.to_vec(),
383 })
384 .await?;
385
386 framed
387 .send(Message::Execute {
388 portal: "".to_string(),
389 max_rows: 0,
390 })
391 .await?;
392
393 framed.send(Message::Sync).await?;
394 }
395
396 let mut results = Vec::new();
397 loop {
398 let msg = framed
399 .next()
400 .await
401 .ok_or_else(|| DsError::protocol("Connection closed during query"))??;
402
403 match msg {
404 BackendMessage::RowDescription { .. } => {}
405 BackendMessage::DataRow { values } => {
406 debug!("Postgres received row with {} columns", values.len());
407 let mut row = Vec::with_capacity(values.len());
408 for val in values {
409 match val {
410 Some(bytes) => {
411 let s = String::from_utf8_lossy(&bytes).to_string();
412 row.push(DsValue::Text(s));
413 }
414 None => row.push(DsValue::Null),
415 }
416 }
417 results.push(DsValue::List(row));
418 }
419 BackendMessage::ReadyForQuery { .. } => break,
420 BackendMessage::ErrorResponse { fields } => {
421 let msg = fields
422 .iter()
423 .find(|(t, _)| *t == b'M')
424 .map(|(_, m)| m.clone())
425 .unwrap_or_default();
426 return Err(DsError::protocol(format!("Postgres query error: {}", msg)));
427 }
428 _ => {}
429 }
430 }
431
432 Ok(results)
433 }
434
435 pub async fn begin_transaction(self) -> Result<PostgresTransaction> {
436 Ok(PostgresTransaction::new(self))
437 }
438}