sqlx_postgres/listener.rs
1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::sql_str::{AssertSqlSafe, SqlStr};
11use sqlx_core::transaction::Transaction;
12use sqlx_core::Either;
13use tracing::Instrument;
14
15use crate::describe::Describe;
16use crate::error::Error;
17use crate::executor::{Execute, Executor};
18use crate::message::{BackendMessageFormat, Notification};
19use crate::pool::PoolOptions;
20use crate::pool::{Pool, PoolConnection};
21use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
22
/// A stream of asynchronous notifications from Postgres.
///
/// This listener will auto-reconnect. If the active connection being used ever dies, this
/// listener will detect that event, create a new connection, re-subscribe to all of the
/// originally specified channels, and resume operations as normal.
///
/// Notifications sent while the connection was down are not recovered; see
/// [`PgListener::recv()`] and [`PgListener::try_recv()`].
pub struct PgListener {
    /// Pool that (re)connections are acquired from.
    pool: Pool<Postgres>,
    /// Connection currently in use; `None` after it was detected as lost and before a
    /// replacement has been acquired.
    connection: Option<PoolConnection<Postgres>>,
    /// Receiving half of the notification buffer that the connection's stream pushes into.
    buffer_rx: mpsc::UnboundedReceiver<Notification>,
    /// Holds the buffer's sending half while no connection carries it, so it can be moved onto
    /// the next connection that is acquired.
    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
    /// Channels subscribed to via `listen`/`listen_all`; replayed after every reconnect.
    channels: Vec<String>,
    /// When `true`, waiting in `try_recv()` is not cancelled by `Pool::close_event()`.
    ignore_close_event: bool,
    /// When `true`, `try_recv()` reconnects before returning `Ok(None)` for a lost connection.
    eager_reconnect: bool,
}
38
/// An asynchronous notification from Postgres.
///
/// Inspect it with [`process_id()`](Self::process_id), [`channel()`](Self::channel) and
/// [`payload()`](Self::payload).
pub struct PgNotification(Notification);
41
42impl PgListener {
43    pub async fn connect(url: &str) -> Result<Self, Error> {
44        // Create a pool of 1 without timeouts (as they don't apply here)
45        // We only use the pool to handle re-connections
46        let pool = PoolOptions::<Postgres>::new()
47            .max_connections(1)
48            .max_lifetime(None)
49            .idle_timeout(None)
50            .connect(url)
51            .await?;
52
53        let mut this = Self::connect_with(&pool).await?;
54        // We don't need to handle close events
55        this.ignore_close_event = true;
56
57        Ok(this)
58    }
59
60    pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
61        // Pull out an initial connection
62        let mut connection = pool.acquire().await?;
63
64        // Setup a notification buffer
65        let (sender, receiver) = mpsc::unbounded();
66        connection.inner.stream.notifications = Some(sender);
67
68        Ok(Self {
69            pool: pool.clone(),
70            connection: Some(connection),
71            buffer_rx: receiver,
72            buffer_tx: None,
73            channels: Vec::new(),
74            ignore_close_event: false,
75            eager_reconnect: true,
76        })
77    }
78
79    /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
80    ///
81    /// By default, when [`Pool::close()`] is called on the pool this listener is using
82    /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
83    /// cancelled and `Err(PoolClosed)` is returned.
84    ///
85    /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
86    /// including the one being used by this listener.
87    ///
88    /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
89    /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
90    /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
91    /// on the attempt to acquire a new connection anyway.
92    ///
93    /// However, if you want `PgListener` to ignore the close event and continue waiting for a
94    /// message as long as it can, set this to `true`.
95    ///
96    /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
97    /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
98    pub fn ignore_pool_close_event(&mut self, val: bool) {
99        self.ignore_close_event = val;
100    }
101
102    /// Set whether a lost connection in `try_recv()` should be re-established before it returns
103    /// `Ok(None)`, or on the next call to `try_recv()`.
104    ///
105    /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
106    ///
107    /// If this is set to `false` then notifications will continue to be lost until the next call
108    /// to `try_recv()`. If your recovery logic uses a different database connection then
109    /// notifications that occur after it completes may be lost without any way to tell that they
110    /// have been.
111    pub fn eager_reconnect(&mut self, val: bool) {
112        self.eager_reconnect = val;
113    }
114
115    /// Starts listening for notifications on a channel.
116    /// The channel name is quoted here to ensure case sensitivity.
117    pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
118        self.connection()
119            .await?
120            .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel))))
121            .await?;
122
123        self.channels.push(channel.to_owned());
124
125        Ok(())
126    }
127
128    /// Starts listening for notifications on all channels.
129    pub async fn listen_all(
130        &mut self,
131        channels: impl IntoIterator<Item = &str>,
132    ) -> Result<(), Error> {
133        let beg = self.channels.len();
134        self.channels.extend(channels.into_iter().map(|s| s.into()));
135
136        let query = build_listen_all_query(&self.channels[beg..]);
137        self.connection()
138            .await?
139            .execute(AssertSqlSafe(query))
140            .await?;
141
142        Ok(())
143    }
144
145    /// Stops listening for notifications on a channel.
146    /// The channel name is quoted here to ensure case sensitivity.
147    pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
148        // use RAW connection and do NOT re-connect automatically, since this is not required for
149        // UNLISTEN (we've disconnected anyways)
150        if let Some(connection) = self.connection.as_mut() {
151            connection
152                .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel))))
153                .await?;
154        }
155
156        if let Some(pos) = self.channels.iter().position(|s| s == channel) {
157            self.channels.remove(pos);
158        }
159
160        Ok(())
161    }
162
163    /// Stops listening for notifications on all channels.
164    pub async fn unlisten_all(&mut self) -> Result<(), Error> {
165        // use RAW connection and do NOT re-connect automatically, since this is not required for
166        // UNLISTEN (we've disconnected anyways)
167        if let Some(connection) = self.connection.as_mut() {
168            connection.execute("UNLISTEN *").await?;
169        }
170
171        self.channels.clear();
172
173        Ok(())
174    }
175
176    #[inline]
177    async fn connect_if_needed(&mut self) -> Result<(), Error> {
178        if self.connection.is_none() {
179            let mut connection = self.pool.acquire().await?;
180            connection.inner.stream.notifications = self.buffer_tx.take();
181
182            connection
183                .execute(AssertSqlSafe(build_listen_all_query(&self.channels)))
184                .await?;
185
186            self.connection = Some(connection);
187        }
188
189        Ok(())
190    }
191
192    #[inline]
193    async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
194        // Ensure we have an active connection to work with.
195        self.connect_if_needed().await?;
196
197        Ok(self.connection.as_mut().unwrap())
198    }
199
200    /// Receives the next notification available from any of the subscribed channels.
201    ///
202    /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
203    /// call to `recv()`, and should be entirely transparent (as long as it was just an
204    /// intermittent network failure or long-lived connection reaper).
205    ///
206    /// As notifications are transient, any received while the connection was lost, will not
207    /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
208    /// do something before, please see [`try_recv`](Self::try_recv).
209    ///
210    /// # Example
211    ///
212    /// ```rust,no_run
213    /// # use sqlx::postgres::PgListener;
214    /// #
215    /// # sqlx::__rt::test_block_on(async move {
216    /// let mut listener = PgListener::connect("postgres:// ...").await?;
217    /// loop {
218    ///     // ask for next notification, re-connecting (transparently) if needed
219    ///     let notification = listener.recv().await?;
220    ///
221    ///     // handle notification, do something interesting
222    /// }
223    /// # Result::<(), sqlx::Error>::Ok(())
224    /// # }).unwrap();
225    /// ```
226    pub async fn recv(&mut self) -> Result<PgNotification, Error> {
227        loop {
228            if let Some(notification) = self.try_recv().await? {
229                return Ok(notification);
230            }
231        }
232    }
233
234    /// Receives the next notification available from any of the subscribed channels.
235    ///
236    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
237    /// reconnected either immediately, or on the next call to `try_recv()`, depending on
238    /// the value of [`eager_reconnect`].
239    ///
240    /// # Example
241    ///
242    /// ```rust,no_run
243    /// # use sqlx::postgres::PgListener;
244    /// #
245    /// # sqlx::__rt::test_block_on(async move {
246    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
247    /// loop {
248    ///     // start handling notifications, connecting if needed
249    ///     while let Some(notification) = listener.try_recv().await? {
250    ///         // handle notification
251    ///     }
252    ///
253    ///     // connection lost, do something interesting
254    /// }
255    /// # Result::<(), sqlx::Error>::Ok(())
256    /// # }).unwrap();
257    /// ```
258    ///
259    /// [`eager_reconnect`]: PgListener::eager_reconnect
260    pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
261        // Flush the buffer first, if anything
262        // This would only fill up if this listener is used as a connection
263        if let Some(notification) = self.next_buffered() {
264            return Ok(Some(notification));
265        }
266
267        // Fetch our `CloseEvent` listener, if applicable.
268        let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
269
270        loop {
271            let next_message = self.connection().await?.inner.stream.recv_unchecked();
272
273            let res = if let Some(ref mut close_event) = close_event {
274                // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
275                // before `next_message` returns, or if the pool was already closed
276                close_event.do_until(next_message).await?
277            } else {
278                next_message.await
279            };
280
281            let message = match res {
282                Ok(message) => message,
283
284                // The connection is dead, ensure that it is dropped,
285                // update self state, and loop to try again.
286                Err(Error::Io(err))
287                    if matches!(
288                        err.kind(),
289                        io::ErrorKind::ConnectionAborted |
290                        io::ErrorKind::UnexpectedEof |
291                        // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
292                        io::ErrorKind::TimedOut |
293                        io::ErrorKind::BrokenPipe
294                    ) =>
295                {
296                    if let Some(mut conn) = self.connection.take() {
297                        self.buffer_tx = conn.inner.stream.notifications.take();
298                        // Close the connection in a background task, so we can continue.
299                        conn.close_on_drop();
300                    }
301
302                    if self.eager_reconnect {
303                        self.connect_if_needed().await?;
304                    }
305
306                    // lost connection
307                    return Ok(None);
308                }
309
310                // Forward other errors
311                Err(error) => {
312                    return Err(error);
313                }
314            };
315
316            match message.format {
317                // We've received an async notification, return it.
318                BackendMessageFormat::NotificationResponse => {
319                    return Ok(Some(PgNotification(message.decode()?)));
320                }
321
322                // Mark the connection as ready for another query
323                BackendMessageFormat::ReadyForQuery => {
324                    self.connection().await?.inner.pending_ready_for_query_count -= 1;
325                }
326
327                // Ignore unexpected messages
328                _ => {}
329            }
330        }
331    }
332
333    /// Receives the next notification that already exists in the connection buffer, if any.
334    ///
335    /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
336    ///
337    /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
338    pub fn next_buffered(&mut self) -> Option<PgNotification> {
339        if let Ok(Some(notification)) = self.buffer_rx.try_next() {
340            Some(PgNotification(notification))
341        } else {
342            None
343        }
344    }
345
346    /// Consume this listener, returning a `Stream` of notifications.
347    ///
348    /// The backing connection will be automatically reconnected should it be lost.
349    ///
350    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
351    ///
352    pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
353        Box::pin(try_stream! {
354            loop {
355                r#yield!(self.recv().await?);
356            }
357        })
358    }
359}
360
361impl Drop for PgListener {
362    fn drop(&mut self) {
363        if let Some(mut conn) = self.connection.take() {
364            let fut = async move {
365                let _ = conn.execute("UNLISTEN *").await;
366
367                // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
368                // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
369                // https://github.com/launchbadge/sqlx/issues/1389
370                conn.return_to_pool().await;
371            };
372
373            // Unregister any listeners before returning the connection to the pool.
374            crate::rt::spawn(fut.in_current_span());
375        }
376    }
377}
378
379impl<'c> Acquire<'c> for &'c mut PgListener {
380    type Database = Postgres;
381    type Connection = &'c mut PgConnection;
382
383    fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
384        self.connection().boxed()
385    }
386
387    fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
388        self.connection().and_then(|c| c.begin()).boxed()
389    }
390}
391
392impl<'c> Executor<'c> for &'c mut PgListener {
393    type Database = Postgres;
394
395    fn fetch_many<'e, 'q, E>(
396        self,
397        query: E,
398    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
399    where
400        'c: 'e,
401        E: Execute<'q, Self::Database>,
402        'q: 'e,
403        E: 'q,
404    {
405        futures_util::stream::once(async move {
406            // need some basic type annotation to help the compiler a bit
407            let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
408            res
409        })
410        .try_flatten()
411        .boxed()
412    }
413
414    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
415    where
416        'c: 'e,
417        E: Execute<'q, Self::Database>,
418        'q: 'e,
419        E: 'q,
420    {
421        async move { self.connection().await?.fetch_optional(query).await }.boxed()
422    }
423
424    fn prepare_with<'e>(
425        self,
426        query: SqlStr,
427        parameters: &'e [PgTypeInfo],
428    ) -> BoxFuture<'e, Result<PgStatement, Error>>
429    where
430        'c: 'e,
431    {
432        async move {
433            self.connection()
434                .await?
435                .prepare_with(query, parameters)
436                .await
437        }
438        .boxed()
439    }
440
441    #[doc(hidden)]
442    fn describe<'e>(self, query: SqlStr) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
443    where
444        'c: 'e,
445    {
446        async move { self.connection().await?.describe(query).await }.boxed()
447    }
448}
449
450impl PgNotification {
451    /// The process ID of the notifying backend process.
452    #[inline]
453    pub fn process_id(&self) -> u32 {
454        self.0.process_id
455    }
456
457    /// The channel that the notify has been raised on. This can be thought
458    /// of as the message topic.
459    #[inline]
460    pub fn channel(&self) -> &str {
461        from_utf8(&self.0.channel).unwrap()
462    }
463
464    /// The payload of the notification. An empty payload is received as an
465    /// empty string.
466    #[inline]
467    pub fn payload(&self) -> &str {
468        from_utf8(&self.0.payload).unwrap()
469    }
470}
471
472impl Debug for PgListener {
473    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
474        f.debug_struct("PgListener").finish()
475    }
476}
477
478impl Debug for PgNotification {
479    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480        f.debug_struct("PgNotification")
481            .field("process_id", &self.process_id())
482            .field("channel", &self.channel())
483            .field("payload", &self.payload())
484            .finish()
485    }
486}
487
/// Escapes `name` for use inside a double-quoted SQL identifier.
///
/// Everything from the first NUL byte onward is discarded, and each `"` is doubled.
fn ident(name: &str) -> String {
    // `split` always yields at least one piece: the text before the first NUL (or all of it).
    let kept = name.split('\0').next().unwrap_or_default();

    // Inside a quoted identifier a literal double quote is written as two.
    kept.replace('"', "\"\"")
}
498
499fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
500    channels.into_iter().fold(String::new(), |mut acc, chan| {
501        acc.push_str(r#"LISTEN ""#);
502        acc.push_str(&ident(chan.as_ref()));
503        acc.push_str(r#"";"#);
504        acc
505    })
506}
507
#[test]
fn test_build_listen_all_query_with_single_channel() {
    // A single channel yields exactly one quoted, semicolon-terminated statement.
    let query = build_listen_all_query(["test"]);
    assert_eq!(query, r#"LISTEN "test";"#);
}
513
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    // Statements are concatenated in input order with no separator beyond the `;`.
    let expected = r#"LISTEN "channel.0";LISTEN "channel.1";"#;
    assert_eq!(build_listen_all_query(["channel.0", "channel.1"]), expected);
}