toasty_driver_postgresql/
lib.rs1#![warn(missing_docs)]
2
3mod statement_cache;
15#[cfg(feature = "tls")]
16mod tls;
17mod r#type;
18mod value;
19
20pub(crate) use value::Value;
21
22use async_trait::async_trait;
23use percent_encoding::percent_decode_str;
24use std::{borrow::Cow, sync::Arc};
25use toasty_core::{
26 Result, Schema,
27 driver::{Capability, Driver, ExecResponse, Operation},
28 schema::db::{self, Migration, SchemaDiff, Table},
29 stmt,
30 stmt::ValueRecord,
31};
32use toasty_sql::{self as sql, TypedValue};
33use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
34use url::Url;
35
36use crate::{statement_cache::StatementCache, r#type::TypeExt};
37
38#[derive(Debug)]
48pub struct PostgreSQL {
49 url: String,
50 config: Config,
51 #[cfg(feature = "tls")]
52 tls: Option<tls::MakeRustlsConnect>,
53}
54
55impl PostgreSQL {
56 pub fn new(url: impl Into<String>) -> Result<Self> {
58 let url_str = url.into();
59 let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
60
61 if !matches!(url.scheme(), "postgresql" | "postgres") {
62 return Err(toasty_core::Error::invalid_connection_url(format!(
63 "connection URL does not have a `postgresql` scheme; url={}",
64 url
65 )));
66 }
67
68 let host = url.host_str().ok_or_else(|| {
69 toasty_core::Error::invalid_connection_url(format!(
70 "missing host in connection URL; url={}",
71 url
72 ))
73 })?;
74
75 if url.path().is_empty() {
76 return Err(toasty_core::Error::invalid_connection_url(format!(
77 "no database specified - missing path in connection URL; url={}",
78 url
79 )));
80 }
81
82 let mut config = Config::new();
83 config.host(host);
84
85 let dbname = percent_decode_str(url.path().trim_start_matches('/'))
86 .decode_utf8()
87 .map_err(|_| {
88 toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
89 })?;
90 config.dbname(&*dbname);
91
92 if let Some(port) = url.port() {
93 config.port(port);
94 }
95
96 if !url.username().is_empty() {
97 let user = percent_decode_str(url.username())
98 .decode_utf8()
99 .map_err(|_| {
100 toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
101 })?;
102 config.user(&*user);
103 }
104
105 if let Some(password) = url.password() {
106 config.password(percent_decode_str(password).collect::<Vec<u8>>());
107 }
108
109 #[cfg(feature = "tls")]
110 let tls = tls::configure_tls(&url, &mut config)?;
111
112 #[cfg(not(feature = "tls"))]
113 for (key, value) in url.query_pairs() {
114 if key == "sslmode" && value != "disable" {
115 return Err(toasty_core::Error::invalid_connection_url(
116 "TLS not available: compile with the `tls` feature",
117 ));
118 }
119 }
120
121 Ok(Self {
122 url: url_str,
123 config,
124 #[cfg(feature = "tls")]
125 tls,
126 })
127 }
128
129 async fn connect_with_config(&self, config: Config) -> Result<Connection> {
130 #[cfg(feature = "tls")]
131 if let Some(ref tls) = self.tls {
132 return Connection::connect(config, tls.clone()).await;
133 }
134 Connection::connect(config, tokio_postgres::NoTls).await
135 }
136}
137
138#[async_trait]
139impl Driver for PostgreSQL {
140 fn url(&self) -> Cow<'_, str> {
141 Cow::Borrowed(&self.url)
142 }
143
144 fn capability(&self) -> &'static Capability {
145 &Capability::POSTGRESQL
146 }
147
148 async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
149 Ok(Box::new(
150 self.connect_with_config(self.config.clone()).await?,
151 ))
152 }
153
154 fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
155 let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
156
157 let sql_strings: Vec<String> = statements
158 .iter()
159 .map(|stmt| {
160 let mut params = Vec::<TypedValue>::new();
161 let sql = sql::Serializer::postgresql(stmt.schema())
162 .serialize(stmt.statement(), &mut params);
163 assert!(
164 params.is_empty(),
165 "migration statements should not have parameters"
166 );
167 sql
168 })
169 .collect();
170
171 Migration::new_sql(sql_strings.join("\n"))
172 }
173
174 async fn reset_db(&self) -> toasty_core::Result<()> {
175 let dbname = self
176 .config
177 .get_dbname()
178 .ok_or_else(|| {
179 toasty_core::Error::invalid_connection_url("no database name configured")
180 })?
181 .to_string();
182
183 let temp_dbname = "__toasty_reset_temp";
185
186 let connect = |dbname: &str| {
187 let mut config = self.config.clone();
188 config.dbname(dbname);
189 self.connect_with_config(config)
190 };
191
192 let conn = connect(&dbname).await?;
194 conn.client
195 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
196 .await
197 .map_err(toasty_core::Error::driver_operation_failed)?;
198 conn.client
199 .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
200 .await
201 .map_err(toasty_core::Error::driver_operation_failed)?;
202 drop(conn);
203
204 let conn = connect(temp_dbname).await?;
206 conn.client
207 .execute(
208 "SELECT pg_terminate_backend(pid) \
209 FROM pg_stat_activity \
210 WHERE datname = $1 AND pid <> pg_backend_pid()",
211 &[&dbname],
212 )
213 .await
214 .map_err(toasty_core::Error::driver_operation_failed)?;
215 conn.client
216 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
217 .await
218 .map_err(toasty_core::Error::driver_operation_failed)?;
219 conn.client
220 .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
221 .await
222 .map_err(toasty_core::Error::driver_operation_failed)?;
223 drop(conn);
224
225 let conn = connect(&dbname).await?;
227 conn.client
228 .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
229 .await
230 .map_err(toasty_core::Error::driver_operation_failed)?;
231
232 Ok(())
233 }
234}
235
236#[derive(Debug)]
238pub struct Connection {
239 client: Client,
240 statement_cache: StatementCache,
241}
242
243impl Connection {
244 pub fn new(client: Client) -> Self {
246 Self {
247 client,
248 statement_cache: StatementCache::new(100),
249 }
250 }
251
252 pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
256 where
257 T: MakeTlsConnect<Socket> + 'static,
258 T::Stream: Send,
259 {
260 let (client, connection) = config
261 .connect(tls)
262 .await
263 .map_err(toasty_core::Error::driver_operation_failed)?;
264
265 tokio::spawn(async move {
266 if let Err(e) = connection.await {
267 eprintln!("connection error: {e}");
268 }
269 });
270
271 Ok(Self::new(client))
272 }
273
274 pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
276 let serializer = sql::Serializer::postgresql(schema);
277
278 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
279 let sql = serializer.serialize(
280 &sql::Statement::create_table(table, &Capability::POSTGRESQL),
281 &mut params,
282 );
283
284 assert!(
285 params.is_empty(),
286 "creating a table shouldn't involve any parameters"
287 );
288
289 self.client
290 .execute(&sql, &[])
291 .await
292 .map_err(toasty_core::Error::driver_operation_failed)?;
293
294 for index in &table.indices {
297 if index.primary_key {
298 continue;
299 }
300
301 let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
302
303 assert!(
304 params.is_empty(),
305 "creating an index shouldn't involve any parameters"
306 );
307
308 self.client
309 .execute(&sql, &[])
310 .await
311 .map_err(toasty_core::Error::driver_operation_failed)?;
312 }
313
314 Ok(())
315 }
316}
317
318impl From<Client> for Connection {
319 fn from(client: Client) -> Self {
320 Self {
321 client,
322 statement_cache: StatementCache::new(100),
323 }
324 }
325}
326
327#[async_trait]
328impl toasty_core::driver::Connection for Connection {
329 async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
330 tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
331
332 if let Operation::Transaction(ref t) = op {
333 let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
334 self.client.batch_execute(&sql).await.map_err(|e| {
335 if let Some(db_err) = e.as_db_error() {
336 match db_err.code().code() {
337 "40001" => toasty_core::Error::serialization_failure(db_err.message()),
338 "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
339 _ => toasty_core::Error::driver_operation_failed(e),
340 }
341 } else {
342 toasty_core::Error::driver_operation_failed(e)
343 }
344 })?;
345 return Ok(ExecResponse::count(0));
346 }
347
348 let (sql, ret_tys): (sql::Statement, _) = match op {
349 Operation::Insert(op) => (op.stmt.into(), None),
350 Operation::QuerySql(query) => {
351 assert!(
352 query.last_insert_id_hack.is_none(),
353 "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
354 );
355 (query.stmt.into(), query.ret)
356 }
357 op => todo!("op={:#?}", op),
358 };
359
360 let width = sql.returning_len();
361
362 let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
363 let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
364
365 tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = params.len(), "executing SQL");
366
367 let param_types = params
368 .iter()
369 .map(|typed_value| typed_value.infer_ty().to_postgres_type())
370 .collect::<Vec<_>>();
371
372 let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
373 let params = values
374 .iter()
375 .map(|param| param as &(dyn ToSql + Sync))
376 .collect::<Vec<_>>();
377
378 let statement = self
379 .statement_cache
380 .prepare_typed(&mut self.client, &sql_as_str, ¶m_types)
381 .await
382 .map_err(toasty_core::Error::driver_operation_failed)?;
383
384 if width.is_none() {
385 let count = self
386 .client
387 .execute(&statement, ¶ms)
388 .await
389 .map_err(toasty_core::Error::driver_operation_failed)?;
390 return Ok(ExecResponse::count(count));
391 }
392
393 let rows = self
394 .client
395 .query(&statement, ¶ms)
396 .await
397 .map_err(toasty_core::Error::driver_operation_failed)?;
398
399 if width.is_none() {
400 let [row] = &rows[..] else { todo!() };
401 let total = row.get::<usize, i64>(0);
402 let condition_matched = row.get::<usize, i64>(1);
403
404 if total == condition_matched {
405 Ok(ExecResponse::count(total as _))
406 } else {
407 Err(toasty_core::Error::condition_failed(
408 "update condition did not match",
409 ))
410 }
411 } else {
412 let ret_tys = ret_tys.as_ref().unwrap().clone();
413 let results = rows.into_iter().map(move |row| {
414 let mut results = Vec::new();
415 for (i, column) in row.columns().iter().enumerate() {
416 results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
417 }
418
419 Ok(ValueRecord::from_vec(results))
420 });
421
422 Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
423 results,
424 )))
425 }
426 }
427
428 async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
429 for table in &schema.db.tables {
430 tracing::debug!(table = %table.name, "creating table");
431 self.create_table(&schema.db, table).await?;
432 }
433 Ok(())
434 }
435
436 async fn applied_migrations(
437 &mut self,
438 ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
439 self.client
441 .execute(
442 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
443 id BIGINT PRIMARY KEY,
444 name TEXT NOT NULL,
445 applied_at TIMESTAMP NOT NULL
446 )",
447 &[],
448 )
449 .await
450 .map_err(toasty_core::Error::driver_operation_failed)?;
451
452 let rows = self
454 .client
455 .query(
456 "SELECT id FROM __toasty_migrations ORDER BY applied_at",
457 &[],
458 )
459 .await
460 .map_err(toasty_core::Error::driver_operation_failed)?;
461
462 Ok(rows
463 .iter()
464 .map(|row| {
465 let id: i64 = row.get(0);
466 toasty_core::schema::db::AppliedMigration::new(id as u64)
467 })
468 .collect())
469 }
470
471 async fn apply_migration(
472 &mut self,
473 id: u64,
474 name: &str,
475 migration: &toasty_core::schema::db::Migration,
476 ) -> Result<()> {
477 tracing::info!(id = id, name = %name, "applying migration");
478 self.client
480 .execute(
481 "CREATE TABLE IF NOT EXISTS __toasty_migrations (
482 id BIGINT PRIMARY KEY,
483 name TEXT NOT NULL,
484 applied_at TIMESTAMP NOT NULL
485 )",
486 &[],
487 )
488 .await
489 .map_err(toasty_core::Error::driver_operation_failed)?;
490
491 let transaction = self
493 .client
494 .transaction()
495 .await
496 .map_err(toasty_core::Error::driver_operation_failed)?;
497
498 for statement in migration.statements() {
500 if let Err(e) = transaction
501 .batch_execute(statement)
502 .await
503 .map_err(toasty_core::Error::driver_operation_failed)
504 {
505 transaction
506 .rollback()
507 .await
508 .map_err(toasty_core::Error::driver_operation_failed)?;
509 return Err(e);
510 }
511 }
512
513 if let Err(e) = transaction
515 .execute(
516 "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
517 &[&(id as i64), &name],
518 )
519 .await
520 .map_err(toasty_core::Error::driver_operation_failed)
521 {
522 transaction
523 .rollback()
524 .await
525 .map_err(toasty_core::Error::driver_operation_failed)?;
526 return Err(e);
527 }
528
529 transaction
531 .commit()
532 .await
533 .map_err(toasty_core::Error::driver_operation_failed)?;
534 Ok(())
535 }
536}