Skip to main content

toasty_driver_postgresql/
lib.rs

1#![warn(missing_docs)]
2
3//! Toasty driver for [PostgreSQL](https://www.postgresql.org/) using
4//! [`tokio-postgres`](https://docs.rs/tokio-postgres).
5//!
6//! # Examples
7//!
8//! ```no_run
9//! use toasty_driver_postgresql::PostgreSQL;
10//!
11//! let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
12//! ```
13
14mod statement_cache;
15#[cfg(feature = "tls")]
16mod tls;
17mod r#type;
18mod value;
19
20pub(crate) use value::Value;
21
22use async_trait::async_trait;
23use percent_encoding::percent_decode_str;
24use std::{borrow::Cow, sync::Arc};
25use toasty_core::{
26    Result, Schema,
27    driver::{Capability, Driver, ExecResponse, Operation},
28    schema::db::{self, Migration, SchemaDiff, Table},
29    stmt,
30    stmt::ValueRecord,
31};
32use toasty_sql::{self as sql, TypedValue};
33use tokio_postgres::{Client, Config, Socket, tls::MakeTlsConnect, types::ToSql};
34use url::Url;
35
36use crate::{statement_cache::StatementCache, r#type::TypeExt};
37
38/// A PostgreSQL [`Driver`] that connects via `tokio-postgres`.
39///
40/// # Examples
41///
42/// ```no_run
43/// use toasty_driver_postgresql::PostgreSQL;
44///
45/// let driver = PostgreSQL::new("postgresql://localhost/mydb").unwrap();
46/// ```
47#[derive(Debug)]
48pub struct PostgreSQL {
49    url: String,
50    config: Config,
51    #[cfg(feature = "tls")]
52    tls: Option<tls::MakeRustlsConnect>,
53}
54
55impl PostgreSQL {
56    /// Create a new PostgreSQL driver from a connection URL
57    pub fn new(url: impl Into<String>) -> Result<Self> {
58        let url_str = url.into();
59        let url = Url::parse(&url_str).map_err(toasty_core::Error::driver_operation_failed)?;
60
61        if !matches!(url.scheme(), "postgresql" | "postgres") {
62            return Err(toasty_core::Error::invalid_connection_url(format!(
63                "connection URL does not have a `postgresql` scheme; url={}",
64                url
65            )));
66        }
67
68        let host = url.host_str().ok_or_else(|| {
69            toasty_core::Error::invalid_connection_url(format!(
70                "missing host in connection URL; url={}",
71                url
72            ))
73        })?;
74
75        if url.path().is_empty() {
76            return Err(toasty_core::Error::invalid_connection_url(format!(
77                "no database specified - missing path in connection URL; url={}",
78                url
79            )));
80        }
81
82        let mut config = Config::new();
83        config.host(host);
84
85        let dbname = percent_decode_str(url.path().trim_start_matches('/'))
86            .decode_utf8()
87            .map_err(|_| {
88                toasty_core::Error::invalid_connection_url("database name is not valid UTF-8")
89            })?;
90        config.dbname(&*dbname);
91
92        if let Some(port) = url.port() {
93            config.port(port);
94        }
95
96        if !url.username().is_empty() {
97            let user = percent_decode_str(url.username())
98                .decode_utf8()
99                .map_err(|_| {
100                    toasty_core::Error::invalid_connection_url("username is not valid UTF-8")
101                })?;
102            config.user(&*user);
103        }
104
105        if let Some(password) = url.password() {
106            config.password(percent_decode_str(password).collect::<Vec<u8>>());
107        }
108
109        #[cfg(feature = "tls")]
110        let tls = tls::configure_tls(&url, &mut config)?;
111
112        #[cfg(not(feature = "tls"))]
113        for (key, value) in url.query_pairs() {
114            if key == "sslmode" && value != "disable" {
115                return Err(toasty_core::Error::invalid_connection_url(
116                    "TLS not available: compile with the `tls` feature",
117                ));
118            }
119        }
120
121        Ok(Self {
122            url: url_str,
123            config,
124            #[cfg(feature = "tls")]
125            tls,
126        })
127    }
128
129    async fn connect_with_config(&self, config: Config) -> Result<Connection> {
130        #[cfg(feature = "tls")]
131        if let Some(ref tls) = self.tls {
132            return Connection::connect(config, tls.clone()).await;
133        }
134        Connection::connect(config, tokio_postgres::NoTls).await
135    }
136}
137
138#[async_trait]
139impl Driver for PostgreSQL {
140    fn url(&self) -> Cow<'_, str> {
141        Cow::Borrowed(&self.url)
142    }
143
144    fn capability(&self) -> &'static Capability {
145        &Capability::POSTGRESQL
146    }
147
148    async fn connect(&self) -> toasty_core::Result<Box<dyn toasty_core::driver::Connection>> {
149        Ok(Box::new(
150            self.connect_with_config(self.config.clone()).await?,
151        ))
152    }
153
154    fn generate_migration(&self, schema_diff: &SchemaDiff<'_>) -> Migration {
155        let statements = sql::MigrationStatement::from_diff(schema_diff, &Capability::POSTGRESQL);
156
157        let sql_strings: Vec<String> = statements
158            .iter()
159            .map(|stmt| {
160                let mut params = Vec::<TypedValue>::new();
161                let sql = sql::Serializer::postgresql(stmt.schema())
162                    .serialize(stmt.statement(), &mut params);
163                assert!(
164                    params.is_empty(),
165                    "migration statements should not have parameters"
166                );
167                sql
168            })
169            .collect();
170
171        Migration::new_sql(sql_strings.join("\n"))
172    }
173
174    async fn reset_db(&self) -> toasty_core::Result<()> {
175        let dbname = self
176            .config
177            .get_dbname()
178            .ok_or_else(|| {
179                toasty_core::Error::invalid_connection_url("no database name configured")
180            })?
181            .to_string();
182
183        // We cannot drop a database we are currently connected to, so we need a temp database.
184        let temp_dbname = "__toasty_reset_temp";
185
186        let connect = |dbname: &str| {
187            let mut config = self.config.clone();
188            config.dbname(dbname);
189            self.connect_with_config(config)
190        };
191
192        // Step 1: Connect to the target DB and create a temp DB
193        let conn = connect(&dbname).await?;
194        conn.client
195            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
196            .await
197            .map_err(toasty_core::Error::driver_operation_failed)?;
198        conn.client
199            .execute(&format!("CREATE DATABASE \"{}\"", temp_dbname), &[])
200            .await
201            .map_err(toasty_core::Error::driver_operation_failed)?;
202        drop(conn);
203
204        // Step 2: Connect to the temp DB, drop and recreate the target
205        let conn = connect(temp_dbname).await?;
206        conn.client
207            .execute(
208                "SELECT pg_terminate_backend(pid) \
209                 FROM pg_stat_activity \
210                 WHERE datname = $1 AND pid <> pg_backend_pid()",
211                &[&dbname],
212            )
213            .await
214            .map_err(toasty_core::Error::driver_operation_failed)?;
215        conn.client
216            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", dbname), &[])
217            .await
218            .map_err(toasty_core::Error::driver_operation_failed)?;
219        conn.client
220            .execute(&format!("CREATE DATABASE \"{}\"", dbname), &[])
221            .await
222            .map_err(toasty_core::Error::driver_operation_failed)?;
223        drop(conn);
224
225        // Step 3: Connect back to the target and clean up the temp DB
226        let conn = connect(&dbname).await?;
227        conn.client
228            .execute(&format!("DROP DATABASE IF EXISTS \"{}\"", temp_dbname), &[])
229            .await
230            .map_err(toasty_core::Error::driver_operation_failed)?;
231
232        Ok(())
233    }
234}
235
/// An open connection to a PostgreSQL database.
///
/// Wraps a `tokio-postgres` [`Client`] together with a cache of prepared
/// statements that is consulted when executing SQL operations.
#[derive(Debug)]
pub struct Connection {
    // Client used to send every query and statement to the server.
    client: Client,
    // Prepared-statement cache, created with `StatementCache::new(100)`.
    statement_cache: StatementCache,
}
242
243impl Connection {
244    /// Initialize a Toasty PostgreSQL connection using an initialized client.
245    pub fn new(client: Client) -> Self {
246        Self {
247            client,
248            statement_cache: StatementCache::new(100),
249        }
250    }
251
252    /// Connects to a PostgreSQL database using a [`postgres::Config`].
253    ///
254    /// See [`postgres::Client::configure`] for more information.
255    pub async fn connect<T>(config: Config, tls: T) -> Result<Self>
256    where
257        T: MakeTlsConnect<Socket> + 'static,
258        T::Stream: Send,
259    {
260        let (client, connection) = config
261            .connect(tls)
262            .await
263            .map_err(toasty_core::Error::driver_operation_failed)?;
264
265        tokio::spawn(async move {
266            if let Err(e) = connection.await {
267                eprintln!("connection error: {e}");
268            }
269        });
270
271        Ok(Self::new(client))
272    }
273
274    /// Creates a table.
275    pub async fn create_table(&mut self, schema: &db::Schema, table: &Table) -> Result<()> {
276        let serializer = sql::Serializer::postgresql(schema);
277
278        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
279        let sql = serializer.serialize(
280            &sql::Statement::create_table(table, &Capability::POSTGRESQL),
281            &mut params,
282        );
283
284        assert!(
285            params.is_empty(),
286            "creating a table shouldn't involve any parameters"
287        );
288
289        self.client
290            .execute(&sql, &[])
291            .await
292            .map_err(toasty_core::Error::driver_operation_failed)?;
293
294        // NOTE: `params` is guaranteed to be empty based on the assertion above. If
295        // that changes, `params.clear()` should be called here.
296        for index in &table.indices {
297            if index.primary_key {
298                continue;
299            }
300
301            let sql = serializer.serialize(&sql::Statement::create_index(index), &mut params);
302
303            assert!(
304                params.is_empty(),
305                "creating an index shouldn't involve any parameters"
306            );
307
308            self.client
309                .execute(&sql, &[])
310                .await
311                .map_err(toasty_core::Error::driver_operation_failed)?;
312        }
313
314        Ok(())
315    }
316}
317
318impl From<Client> for Connection {
319    fn from(client: Client) -> Self {
320        Self {
321            client,
322            statement_cache: StatementCache::new(100),
323        }
324    }
325}
326
327#[async_trait]
328impl toasty_core::driver::Connection for Connection {
329    async fn exec(&mut self, schema: &Arc<Schema>, op: Operation) -> Result<ExecResponse> {
330        tracing::trace!(driver = "postgresql", op = %op.name(), "driver exec");
331
332        if let Operation::Transaction(ref t) = op {
333            let sql = sql::Serializer::postgresql(&schema.db).serialize_transaction(t);
334            self.client.batch_execute(&sql).await.map_err(|e| {
335                if let Some(db_err) = e.as_db_error() {
336                    match db_err.code().code() {
337                        "40001" => toasty_core::Error::serialization_failure(db_err.message()),
338                        "25006" => toasty_core::Error::read_only_transaction(db_err.message()),
339                        _ => toasty_core::Error::driver_operation_failed(e),
340                    }
341                } else {
342                    toasty_core::Error::driver_operation_failed(e)
343                }
344            })?;
345            return Ok(ExecResponse::count(0));
346        }
347
348        let (sql, ret_tys): (sql::Statement, _) = match op {
349            Operation::Insert(op) => (op.stmt.into(), None),
350            Operation::QuerySql(query) => {
351                assert!(
352                    query.last_insert_id_hack.is_none(),
353                    "last_insert_id_hack is MySQL-specific and should not be set for PostgreSQL"
354                );
355                (query.stmt.into(), query.ret)
356            }
357            op => todo!("op={:#?}", op),
358        };
359
360        let width = sql.returning_len();
361
362        let mut params: Vec<toasty_sql::TypedValue> = Vec::new();
363        let sql_as_str = sql::Serializer::postgresql(&schema.db).serialize(&sql, &mut params);
364
365        tracing::debug!(db.system = "postgresql", db.statement = %sql_as_str, params = params.len(), "executing SQL");
366
367        let param_types = params
368            .iter()
369            .map(|typed_value| typed_value.infer_ty().to_postgres_type())
370            .collect::<Vec<_>>();
371
372        let values: Vec<_> = params.into_iter().map(|tv| Value::from(tv.value)).collect();
373        let params = values
374            .iter()
375            .map(|param| param as &(dyn ToSql + Sync))
376            .collect::<Vec<_>>();
377
378        let statement = self
379            .statement_cache
380            .prepare_typed(&mut self.client, &sql_as_str, &param_types)
381            .await
382            .map_err(toasty_core::Error::driver_operation_failed)?;
383
384        if width.is_none() {
385            let count = self
386                .client
387                .execute(&statement, &params)
388                .await
389                .map_err(toasty_core::Error::driver_operation_failed)?;
390            return Ok(ExecResponse::count(count));
391        }
392
393        let rows = self
394            .client
395            .query(&statement, &params)
396            .await
397            .map_err(toasty_core::Error::driver_operation_failed)?;
398
399        if width.is_none() {
400            let [row] = &rows[..] else { todo!() };
401            let total = row.get::<usize, i64>(0);
402            let condition_matched = row.get::<usize, i64>(1);
403
404            if total == condition_matched {
405                Ok(ExecResponse::count(total as _))
406            } else {
407                Err(toasty_core::Error::condition_failed(
408                    "update condition did not match",
409                ))
410            }
411        } else {
412            let ret_tys = ret_tys.as_ref().unwrap().clone();
413            let results = rows.into_iter().map(move |row| {
414                let mut results = Vec::new();
415                for (i, column) in row.columns().iter().enumerate() {
416                    results.push(Value::from_sql(i, &row, column, &ret_tys[i]).into_inner());
417                }
418
419                Ok(ValueRecord::from_vec(results))
420            });
421
422            Ok(ExecResponse::value_stream(stmt::ValueStream::from_iter(
423                results,
424            )))
425        }
426    }
427
428    async fn push_schema(&mut self, schema: &Schema) -> Result<()> {
429        for table in &schema.db.tables {
430            tracing::debug!(table = %table.name, "creating table");
431            self.create_table(&schema.db, table).await?;
432        }
433        Ok(())
434    }
435
436    async fn applied_migrations(
437        &mut self,
438    ) -> Result<Vec<toasty_core::schema::db::AppliedMigration>> {
439        // Ensure the migrations table exists
440        self.client
441            .execute(
442                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
443                id BIGINT PRIMARY KEY,
444                name TEXT NOT NULL,
445                applied_at TIMESTAMP NOT NULL
446            )",
447                &[],
448            )
449            .await
450            .map_err(toasty_core::Error::driver_operation_failed)?;
451
452        // Query all applied migrations
453        let rows = self
454            .client
455            .query(
456                "SELECT id FROM __toasty_migrations ORDER BY applied_at",
457                &[],
458            )
459            .await
460            .map_err(toasty_core::Error::driver_operation_failed)?;
461
462        Ok(rows
463            .iter()
464            .map(|row| {
465                let id: i64 = row.get(0);
466                toasty_core::schema::db::AppliedMigration::new(id as u64)
467            })
468            .collect())
469    }
470
471    async fn apply_migration(
472        &mut self,
473        id: u64,
474        name: &str,
475        migration: &toasty_core::schema::db::Migration,
476    ) -> Result<()> {
477        tracing::info!(id = id, name = %name, "applying migration");
478        // Ensure the migrations table exists
479        self.client
480            .execute(
481                "CREATE TABLE IF NOT EXISTS __toasty_migrations (
482                id BIGINT PRIMARY KEY,
483                name TEXT NOT NULL,
484                applied_at TIMESTAMP NOT NULL
485            )",
486                &[],
487            )
488            .await
489            .map_err(toasty_core::Error::driver_operation_failed)?;
490
491        // Start transaction
492        let transaction = self
493            .client
494            .transaction()
495            .await
496            .map_err(toasty_core::Error::driver_operation_failed)?;
497
498        // Execute each migration statement
499        for statement in migration.statements() {
500            if let Err(e) = transaction
501                .batch_execute(statement)
502                .await
503                .map_err(toasty_core::Error::driver_operation_failed)
504            {
505                transaction
506                    .rollback()
507                    .await
508                    .map_err(toasty_core::Error::driver_operation_failed)?;
509                return Err(e);
510            }
511        }
512
513        // Record the migration
514        if let Err(e) = transaction
515            .execute(
516                "INSERT INTO __toasty_migrations (id, name, applied_at) VALUES ($1, $2, NOW())",
517                &[&(id as i64), &name],
518            )
519            .await
520            .map_err(toasty_core::Error::driver_operation_failed)
521        {
522            transaction
523                .rollback()
524                .await
525                .map_err(toasty_core::Error::driver_operation_failed)?;
526            return Err(e);
527        }
528
529        // Commit transaction
530        transaction
531            .commit()
532            .await
533            .map_err(toasty_core::Error::driver_operation_failed)?;
534        Ok(())
535    }
536}