tmp_postgrust/
lib.rs

1/*!
2`tmp-postgrust` provides temporary postgresql processes that are cleaned up
3after being dropped.
4
5
6# Inspiration / Similar Projects
7- [tmp-postgres](https://github.com/jfischoff/tmp-postgres)
8- [testing.postgresql](https://github.com/tk0miya/testing.postgresql)
9*/
10#![deny(missing_docs)]
11#![warn(clippy::all, clippy::pedantic)]
12
13/// Methods for Asynchronous API
14#[cfg(feature = "tokio-process")]
15pub mod asynchronous;
16/// Common Errors
17pub mod errors;
18mod search;
19/// Methods for Synchronous API
20pub mod synchronous;
21
22use std::fs::{metadata, set_permissions};
23use std::io::{BufRead, BufReader};
24use std::path::Path;
25use std::sync::atomic::AtomicU32;
26use std::sync::{Arc, Mutex, OnceLock};
27use std::{fs::File, io::Write};
28
29use ctor::dtor;
30use nix::unistd::{Gid, Uid};
31use tempfile::{Builder, TempDir};
32use tracing::{debug, info, instrument};
33
34use crate::errors::{TmpPostgrustError, TmpPostgrustResult};
35
/// Name of the database created by every factory and reported by each process guard.
const TMP_POSTGRUST_DB_NAME: &str = "tmp-postgrust";
/// Name of the role created by every factory and reported by each process guard.
const TMP_POSTGRUST_USER_NAME: &str = "tmp-postgrust-user";

/// Cached `(uid, gid)` pair, initialised at most once.
/// NOTE(review): not read in this file; presumably the non-root account that
/// `chown_to_non_root` in `synchronous`/`asynchronous` hands files to — confirm there.
pub(crate) static POSTGRES_UID_GID: OnceLock<(Uid, Gid)> = OnceLock::new();
40
41/// As the static variables declared by this crate contain values that
42/// need to be dropped at program exit to clean up resources, we use a
43/// `#[dtor]` hack to drop the variables if they have been initialized.
44#[dtor]
45fn cleanup_static() {
46    #[cfg(feature = "tokio-process")]
47    if let Some(factory_mutex) = TOKIO_POSTGRES_FACTORY.get() {
48        let mut guard = factory_mutex.blocking_lock();
49        drop(guard.take());
50    }
51
52    if let Some(factory_mutex) = DEFAULT_POSTGRES_FACTORY.get() {
53        let mut guard = factory_mutex
54            .lock()
55            .expect("Failed to lock default factory mutex.");
56        drop(guard.take());
57    }
58}
59
/// Shared factory behind `new_default_process` and
/// `new_default_process_with_migrations`, created on first use.
///
/// The factory is wrapped in an `Option` so the exit-time `#[dtor]` hook
/// (`cleanup_static`) can take it out and drop it.
static DEFAULT_POSTGRES_FACTORY: OnceLock<Mutex<Option<TmpPostgrustFactory>>> = OnceLock::new();
61
62/// Create a new default instance, initializing the `DEFAULT_POSTGRES_FACTORY` if it
63/// does not already exist.
64///
65/// # Errors
66///
67/// Will return `Err` if postgres is not installed on system
68///
69/// # Panics
70///
71/// Will panic if a `TmpPostgrustFactory::try_new` returns an error the first time the function
72/// is called.
73pub fn new_default_process() -> TmpPostgrustResult<synchronous::ProcessGuard> {
74    let factory_mutex = DEFAULT_POSTGRES_FACTORY.get_or_init(|| {
75        Mutex::new(Some(
76            TmpPostgrustFactory::try_new().expect("Failed to initialize default postgres factory."),
77        ))
78    });
79    let guard = factory_mutex
80        .lock()
81        .expect("Failed to lock default factory mutex.");
82    let factory = guard
83        .as_ref()
84        .expect("Default factory is uninitialized or has been dropped.");
85    factory.new_instance()
86}
87
88/// Create a new default instance, initializing the `DEFAULT_POSTGRES_FACTORY` if it
89/// does not already exist. The function passed as the `migrate` parameters
90/// will be run the first time the factory is initialised.
91///
92/// # Errors
93///
94/// Will return `Err` if postgres is not installed on system
95///
96/// # Panics
97///
98/// Will panic if a `TmpPostgrustFactory::try_new` returns an error the first time the function
99/// is called.
100pub fn new_default_process_with_migrations(
101    migrate: impl Fn(&str) -> Result<(), Box<dyn std::error::Error + Send + Sync>>,
102) -> TmpPostgrustResult<synchronous::ProcessGuard> {
103    let factory_mutex = DEFAULT_POSTGRES_FACTORY.get_or_init(|| {
104        let factory =
105            TmpPostgrustFactory::try_new().expect("Failed to initialize default postgres factory.");
106        factory
107            .run_migrations(migrate)
108            .expect("Failed to run migrations");
109
110        Mutex::new(Some(factory))
111    });
112    let guard = factory_mutex
113        .lock()
114        .expect("Failed to lock default factory mutex.");
115    let factory = guard
116        .as_ref()
117        .expect("Default factory is uninitialized or has been dropped.");
118    factory.new_instance()
119}
120
/// Static factory that can be re-used between tests.
///
/// Backs `new_default_process_async` and `new_default_process_async_with_migrations`,
/// so the database initialisation (`exec_init_db`) is performed once and then shared.
/// The factory is wrapped in an `Option` so the exit-time `#[dtor]` hook
/// (`cleanup_static`) can take it out and drop it.
#[cfg(feature = "tokio-process")]
static TOKIO_POSTGRES_FACTORY: tokio::sync::OnceCell<
    tokio::sync::Mutex<Option<TmpPostgrustFactory>>,
> = tokio::sync::OnceCell::const_new();
126
/// Create a new instance from the shared tokio factory, initializing
/// `TOKIO_POSTGRES_FACTORY` if it does not already exist.
///
/// # Errors
///
/// Will return `Err` if the factory has to be created and
/// `TmpPostgrustFactory::try_new_async` fails (e.g. because postgres is not installed
/// on the system). In that case the factory is left uninitialized and a later call
/// tries again. Will also return `Err` if the new instance cannot be created or started.
///
/// # Panics
///
/// Will panic if the default tokio factory has already been dropped at program exit.
#[cfg(feature = "tokio-process")]
pub async fn new_default_process_async() -> TmpPostgrustResult<asynchronous::ProcessGuard> {
    let factory_mutex = TOKIO_POSTGRES_FACTORY
        .get_or_try_init(|| async {
            TmpPostgrustFactory::try_new_async()
                .await
                .map(|factory| tokio::sync::Mutex::new(Some(factory)))
        })
        .await?;
    // The lock is held until the instance has started, so concurrent callers create
    // their instances one after another.
    let guard = factory_mutex.lock().await;
    let factory = guard
        .as_ref()
        .expect("Default tokio factory is uninitialized or has been dropped.");
    factory.new_instance_async().await
}
153
/// Create a new instance from the shared tokio factory, initializing
/// `TOKIO_POSTGRES_FACTORY` if it does not already exist.
///
/// `migrate` receives a connection string for the factory's template database. It
/// is run only when this call is the one that initializes the factory. If the factory
/// already exists (created by an earlier call to this function or to
/// `new_default_process_async`), `migrate` is not called at all. Migrations run
/// synchronously and block the current executor thread while they execute.
///
/// # Errors
///
/// Will return `Err` if the factory has to be created and either
/// `TmpPostgrustFactory::try_new_async` or the migrations fail. In that case the
/// factory is left uninitialized and a later call tries again. Will also return `Err`
/// if the new instance cannot be created or started.
///
/// # Panics
///
/// Will panic if the default tokio factory has already been dropped at program exit.
#[cfg(feature = "tokio-process")]
pub async fn new_default_process_async_with_migrations(
    migrate: impl FnOnce(&str) -> Result<(), Box<dyn std::error::Error + Send + Sync>>,
) -> TmpPostgrustResult<asynchronous::ProcessGuard> {
    let factory_mutex = TOKIO_POSTGRES_FACTORY
        .get_or_try_init(|| async {
            TmpPostgrustFactory::try_new_async()
                .await
                .and_then(|factory| {
                    // Report migration failures like any other initialization error
                    // instead of panicking; the half-built factory is dropped here.
                    factory.run_migrations(migrate)?;
                    Ok(tokio::sync::Mutex::new(Some(factory)))
                })
        })
        .await?;
    let guard = factory_mutex.lock().await;
    let factory = guard
        .as_ref()
        .expect("Default tokio factory is uninitialized or has been dropped.");
    factory.new_instance_async().await
}
186
/// Factory for creating new temporary postgresql processes.
///
/// The factory initialises a database cluster once in its cache directory. Every new
/// instance starts from a copy of that directory, so the expensive setup (and any
/// migrations) happens only once.
#[derive(Debug)]
pub struct TmpPostgrustFactory {
    /// Directory used as `unix_socket_directories` by every process this factory starts;
    /// shared (via `Arc`) with the process guards.
    socket_dir: Arc<TempDir>,
    /// Template cluster that new instances' data directories are copied from;
    /// shared (via `Arc`) with the process guards.
    cache_dir: Arc<TempDir>,
    /// Contents written to `postgresql.conf` before each process is started.
    config: String,
    /// Port handed to the next process started; incremented on every start.
    next_port: AtomicU32,
}
195
196impl TmpPostgrustFactory {
197    /// Build a Postgresql configuration for temporary databases as a String.
198    fn build_config(socket_dir: &Path) -> String {
199        let mut config = String::new();
200        // Minimize chance of running out of shared memory
201        config.push_str("shared_buffers = '12MB'\n");
202        // Disable TCP connections.
203        config.push_str("listen_addresses = ''\n");
204        // Listen on UNIX socket.
205        config.push_str(&format!(
206            "unix_socket_directories = \'{}\'\n",
207            socket_dir.to_str().unwrap()
208        ));
209
210        config
211    }
212
213    /// Try to create a new factory by creating temporary directories and the necessary config.
214    #[instrument]
215    pub fn try_new() -> TmpPostgrustResult<TmpPostgrustFactory> {
216        let socket_dir = Builder::new()
217            .prefix("tmp-postgrust-socket")
218            .tempdir()
219            .map_err(TmpPostgrustError::CreateSocketDirFailed)?;
220        let cache_dir = Builder::new()
221            .prefix("tmp-postgrust-cache")
222            .tempdir()
223            .map_err(TmpPostgrustError::CreateCacheDirFailed)?;
224
225        synchronous::chown_to_non_root(cache_dir.path())?;
226        synchronous::chown_to_non_root(socket_dir.path())?;
227        synchronous::exec_init_db(cache_dir.path())?;
228
229        let config = TmpPostgrustFactory::build_config(socket_dir.path());
230
231        let factory = TmpPostgrustFactory {
232            socket_dir: Arc::new(socket_dir),
233            cache_dir: Arc::new(cache_dir),
234            config,
235            next_port: AtomicU32::new(5432),
236        };
237        let process = factory.start_postgresql(&factory.cache_dir)?;
238        synchronous::exec_create_user(process.socket_dir.path(), process.port, &process.user_name)?;
239        synchronous::exec_create_db(
240            process.socket_dir.path(),
241            process.port,
242            &process.user_name,
243            &process.db_name,
244        )?;
245
246        Ok(factory)
247    }
248
249    /// Try to create a new factory by creating temporary directories and the necessary config.
250    #[cfg(feature = "tokio-process")]
251    #[instrument]
252    pub async fn try_new_async() -> TmpPostgrustResult<TmpPostgrustFactory> {
253        let socket_dir = Builder::new()
254            .prefix("tmp-postgrust-socket")
255            .tempdir()
256            .map_err(TmpPostgrustError::CreateSocketDirFailed)?;
257        let cache_dir = Builder::new()
258            .prefix("tmp-postgrust-cache")
259            .tempdir()
260            .map_err(TmpPostgrustError::CreateCacheDirFailed)?;
261
262        asynchronous::chown_to_non_root(cache_dir.path()).await?;
263        asynchronous::chown_to_non_root(socket_dir.path()).await?;
264        asynchronous::exec_init_db(cache_dir.path()).await?;
265
266        let config = TmpPostgrustFactory::build_config(socket_dir.path());
267
268        let factory = TmpPostgrustFactory {
269            socket_dir: Arc::new(socket_dir),
270            cache_dir: Arc::new(cache_dir),
271            config,
272            next_port: AtomicU32::new(5432),
273        };
274        let process = factory.start_postgresql_async(&factory.cache_dir).await?;
275        asynchronous::exec_create_user(process.socket_dir.path(), process.port, &process.user_name)
276            .await?;
277        asynchronous::exec_create_db(
278            process.socket_dir.path(),
279            process.port,
280            &process.user_name,
281            &process.db_name,
282        )
283        .await?;
284
285        Ok(factory)
286    }
287
288    /// Run migrations against the cache directory, will cause all subsequent instances
289    /// to be run against a version of the database where the migrations have been applied.
290    ///
291    /// # Errors
292    ///
293    /// Will error if Postgresql is unable to start or if the migrate function returns
294    /// an error.
295    pub fn run_migrations(
296        &self,
297        migrate: impl FnOnce(&str) -> Result<(), Box<dyn std::error::Error + Send + Sync>>,
298    ) -> TmpPostgrustResult<()> {
299        let process = self.start_postgresql(&self.cache_dir)?;
300
301        migrate(&process.connection_string()).map_err(TmpPostgrustError::MigrationsFailed)?;
302
303        Ok(())
304    }
305
306    /// Start a new postgresql instance and return a process guard that will ensure it is cleaned
307    /// up when dropped.
308    #[instrument(skip(self))]
309    pub fn new_instance(&self) -> TmpPostgrustResult<synchronous::ProcessGuard> {
310        let data_directory = Builder::new()
311            .prefix("tmp-postgrust-db")
312            .tempdir()
313            .map_err(TmpPostgrustError::CreateDataDirFailed)?;
314        let data_directory_path = data_directory.path();
315
316        set_permissions(
317            &data_directory,
318            metadata(self.cache_dir.path()).unwrap().permissions(),
319        )
320        .unwrap();
321        synchronous::exec_copy_dir(self.cache_dir.path(), data_directory_path)?;
322
323        if !data_directory_path.join("PG_VERSION").exists() {
324            return Err(TmpPostgrustError::EmptyDataDirectory);
325        };
326
327        self.start_postgresql(&Arc::new(data_directory))
328    }
329
330    #[instrument(skip(self))]
331    fn start_postgresql(
332        &self,
333        dir: &Arc<TempDir>,
334    ) -> TmpPostgrustResult<synchronous::ProcessGuard> {
335        File::create(dir.path().join("postgresql.conf"))
336            .map_err(TmpPostgrustError::CreateConfigFailed)?
337            .write_all(self.config.as_bytes())
338            .map_err(TmpPostgrustError::CreateConfigFailed)?;
339
340        let port = self
341            .next_port
342            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
343
344        synchronous::chown_to_non_root(dir.path())?;
345        let mut postgres_process_handle = synchronous::start_postgres_subprocess(dir.path(), port)?;
346        let stdout = postgres_process_handle.stdout.take().unwrap();
347        let stderr = postgres_process_handle.stderr.take().unwrap();
348
349        let stdout_reader = BufReader::new(stdout).lines();
350        let mut stderr_reader = BufReader::new(stderr).lines();
351
352        while let Some(Ok(line)) = stderr_reader.next() {
353            debug!("Postgresql: {}", line);
354            if line.contains("database system is ready to accept connections") {
355                info!("temporary database system is read to accept connections");
356                break;
357            }
358        }
359
360        Ok(synchronous::ProcessGuard {
361            stdout_reader: Some(stdout_reader),
362            stderr_reader: Some(stderr_reader),
363            port,
364            db_name: TMP_POSTGRUST_DB_NAME.to_string(),
365            user_name: TMP_POSTGRUST_USER_NAME.to_string(),
366            postgres_process: postgres_process_handle,
367            _data_directory: Arc::clone(dir),
368            _cache_directory: Arc::clone(&self.cache_dir),
369            socket_dir: Arc::clone(&self.socket_dir),
370        })
371    }
372
373    /// Start a new postgresql instance and return a process guard that will ensure it is cleaned
374    /// up when dropped.
375    #[cfg(feature = "tokio-process")]
376    #[instrument(skip(self))]
377    pub async fn new_instance_async(&self) -> TmpPostgrustResult<asynchronous::ProcessGuard> {
378        use tokio::fs::{metadata, set_permissions};
379
380        let data_directory = Builder::new()
381            .prefix("tmp-postgrust-db")
382            .tempdir()
383            .map_err(TmpPostgrustError::CreateDataDirFailed)?;
384        let data_directory_path = data_directory.path();
385
386        set_permissions(
387            &data_directory,
388            metadata(self.cache_dir.path()).await.unwrap().permissions(),
389        )
390        .await
391        .unwrap();
392        asynchronous::exec_copy_dir(self.cache_dir.path(), data_directory_path).await?;
393
394        if !data_directory_path.join("PG_VERSION").exists() {
395            return Err(TmpPostgrustError::EmptyDataDirectory);
396        };
397
398        self.start_postgresql_async(&Arc::new(data_directory)).await
399    }
400
401    #[cfg(feature = "tokio-process")]
402    #[instrument(skip(self))]
403    async fn start_postgresql_async(
404        &self,
405        dir: &Arc<TempDir>,
406    ) -> TmpPostgrustResult<asynchronous::ProcessGuard> {
407        use tokio::io::AsyncBufReadExt;
408
409        let process_permit = asynchronous::MAX_CONCURRENT_PROCESSES
410            .acquire()
411            .await
412            .unwrap();
413
414        File::create(dir.path().join("postgresql.conf"))
415            .map_err(TmpPostgrustError::CreateConfigFailed)?
416            .write_all(self.config.as_bytes())
417            .map_err(TmpPostgrustError::CreateConfigFailed)?;
418
419        let port = self
420            .next_port
421            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
422
423        asynchronous::chown_to_non_root(dir.path()).await?;
424        let mut postgres_process_handle =
425            asynchronous::start_postgres_subprocess(dir.path(), port)?;
426        let stdout = postgres_process_handle.stdout.take().unwrap();
427        let stderr = postgres_process_handle.stderr.take().unwrap();
428
429        let stdout_reader = tokio::io::BufReader::new(stdout).lines();
430        let mut stderr_reader = tokio::io::BufReader::new(stderr).lines();
431
432        while let Some(line) = stderr_reader.next_line().await.unwrap() {
433            debug!("Postgresql: {}", line);
434            if line.contains("database system is ready to accept connections") {
435                info!("temporary database system is read to accept connections");
436                break;
437            }
438        }
439
440        Ok(asynchronous::ProcessGuard {
441            stdout_reader: Some(stdout_reader),
442            stderr_reader: Some(stderr_reader),
443            port,
444            db_name: TMP_POSTGRUST_DB_NAME.to_string(),
445            user_name: TMP_POSTGRUST_USER_NAME.to_string(),
446            _data_directory: Arc::clone(dir),
447            _cache_directory: Arc::clone(&self.cache_dir),
448            socket_dir: Arc::clone(&self.socket_dir),
449            postgres_process: postgres_process_handle,
450            _process_permit: process_permit,
451        })
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    use test_log::test;
    use tokio_postgres::NoTls;
    use tracing::error;

    /// Open a client for `connection_string`, driving the connection on a spawned
    /// task that logs (rather than propagates) connection errors.
    async fn connect_and_spawn(connection_string: &str) -> tokio_postgres::Client {
        let (client, connection) = tokio_postgres::connect(connection_string, NoTls)
            .await
            .expect("failed to connect to postgresql");

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("connection error: {}", e);
            }
        });

        client
    }

    /// Start an instance from the shared tokio factory and create a `lock` table in it.
    #[cfg(feature = "tokio-process")]
    async fn create_lock_table_in_default_process() {
        let proc = new_default_process_async().await.unwrap();
        let client = connect_and_spawn(&proc.connection_string()).await;

        // Chance to catch concurrent tests or database that have already been used.
        client.execute("CREATE TABLE lock ();", &[]).await.unwrap();
    }

    #[test(tokio::test)]
    async fn it_works() {
        let factory = TmpPostgrustFactory::try_new().expect("failed to create factory");
        let postgresql_proc = factory
            .new_instance()
            .expect("failed to create a new instance");

        let client = connect_and_spawn(&postgresql_proc.connection_string()).await;
        client.query("SELECT 1;", &[]).await.unwrap();
    }

    #[cfg(feature = "tokio-process")]
    #[test(tokio::test)]
    async fn it_works_async() {
        let factory = TmpPostgrustFactory::try_new_async()
            .await
            .expect("failed to create factory");
        let postgresql_proc = factory
            .new_instance_async()
            .await
            .expect("failed to create a new instance");

        let client = connect_and_spawn(&postgresql_proc.connection_string()).await;
        client.query("SELECT 1;", &[]).await.unwrap();
    }

    #[test(tokio::test)]
    async fn two_simulatenous_processes() {
        let factory = TmpPostgrustFactory::try_new().expect("failed to create factory");
        let first = factory
            .new_instance()
            .expect("failed to create a new instance");
        let second = factory
            .new_instance()
            .expect("failed to create a new instance");

        let first_client = connect_and_spawn(&first.connection_string()).await;
        let second_client = connect_and_spawn(&second.connection_string()).await;

        first_client.query("SELECT 1;", &[]).await.unwrap();
        second_client.query("SELECT 1;", &[]).await.unwrap();
    }

    #[cfg(feature = "tokio-process")]
    #[test(tokio::test)]
    async fn two_simulatenous_processes_async() {
        let factory = TmpPostgrustFactory::try_new_async()
            .await
            .expect("failed to create factory");
        let first = factory
            .new_instance_async()
            .await
            .expect("failed to create a new instance");
        let second = factory
            .new_instance_async()
            .await
            .expect("failed to create a new instance");

        let first_client = connect_and_spawn(&first.connection_string()).await;
        let second_client = connect_and_spawn(&second.connection_string()).await;

        first_client.query("SELECT 1;", &[]).await.unwrap();
        second_client.query("SELECT 1;", &[]).await.unwrap();
    }

    #[cfg(feature = "tokio-process")]
    #[test(tokio::test)]
    async fn default_process_factory_1() {
        create_lock_table_in_default_process().await;
    }

    #[cfg(feature = "tokio-process")]
    #[test(tokio::test)]
    async fn default_process_factory_2() {
        create_lock_table_in_default_process().await;
    }

    #[cfg(feature = "tokio-process")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn default_process_factory_multithread_1() {
        create_lock_table_in_default_process().await;
    }

    #[cfg(feature = "tokio-process")]
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn default_process_factory_multithread_2() {
        create_lock_table_in_default_process().await;
    }
}