// sqlx_build_trust_postgres/listener.rs

1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryStreamExt};
9use sqlx_core::Either;
10
11use crate::describe::Describe;
12use crate::error::Error;
13use crate::executor::{Execute, Executor};
14use crate::message::{MessageFormat, Notification};
15use crate::pool::PoolOptions;
16use crate::pool::{Pool, PoolConnection};
17use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
18
19/// A stream of asynchronous notifications from Postgres.
20///
21/// This listener will auto-reconnect. If the active
22/// connection being used ever dies, this listener will detect that event, create a
23/// new connection, will re-subscribe to all of the originally specified channels, and will resume
24/// operations as normal.
25pub struct PgListener {
26    pool: Pool<Postgres>,
27    connection: Option<PoolConnection<Postgres>>,
28    buffer_rx: mpsc::UnboundedReceiver<Notification>,
29    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
30    channels: Vec<String>,
31    ignore_close_event: bool,
32}
33
34/// An asynchronous notification from Postgres.
35pub struct PgNotification(Notification);
36
37impl PgListener {
38    pub async fn connect(url: &str) -> Result<Self, Error> {
39        // Create a pool of 1 without timeouts (as they don't apply here)
40        // We only use the pool to handle re-connections
41        let pool = PoolOptions::<Postgres>::new()
42            .max_connections(1)
43            .max_lifetime(None)
44            .idle_timeout(None)
45            .connect(url)
46            .await?;
47
48        let mut this = Self::connect_with(&pool).await?;
49        // We don't need to handle close events
50        this.ignore_close_event = true;
51
52        Ok(this)
53    }
54
55    pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
56        // Pull out an initial connection
57        let mut connection = pool.acquire().await?;
58
59        // Setup a notification buffer
60        let (sender, receiver) = mpsc::unbounded();
61        connection.stream.notifications = Some(sender);
62
63        Ok(Self {
64            pool: pool.clone(),
65            connection: Some(connection),
66            buffer_rx: receiver,
67            buffer_tx: None,
68            channels: Vec::new(),
69            ignore_close_event: false,
70        })
71    }
72
73    /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
74    ///
75    /// By default, when [`Pool::close()`] is called on the pool this listener is using
76    /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
77    /// cancelled and `Err(PoolClosed)` is returned.
78    ///
79    /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
80    /// including the one being used by this listener.
81    ///
82    /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
83    /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
84    /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
85    /// on the attempt to acquire a new connection anyway.
86    ///
87    /// However, if you want `PgListener` to ignore the close event and continue waiting for a
88    /// message as long as it can, set this to `true`.
89    ///
90    /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
91    /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
92    pub fn ignore_pool_close_event(&mut self, val: bool) {
93        self.ignore_close_event = val;
94    }
95
96    /// Starts listening for notifications on a channel.
97    /// The channel name is quoted here to ensure case sensitivity.
98    pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
99        self.connection()
100            .await?
101            .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
102            .await?;
103
104        self.channels.push(channel.to_owned());
105
106        Ok(())
107    }
108
109    /// Starts listening for notifications on all channels.
110    pub async fn listen_all(
111        &mut self,
112        channels: impl IntoIterator<Item = &str>,
113    ) -> Result<(), Error> {
114        let beg = self.channels.len();
115        self.channels.extend(channels.into_iter().map(|s| s.into()));
116
117        let query = build_listen_all_query(&self.channels[beg..]);
118        self.connection().await?.execute(&*query).await?;
119
120        Ok(())
121    }
122
123    /// Stops listening for notifications on a channel.
124    /// The channel name is quoted here to ensure case sensitivity.
125    pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
126        // use RAW connection and do NOT re-connect automatically, since this is not required for
127        // UNLISTEN (we've disconnected anyways)
128        if let Some(connection) = self.connection.as_mut() {
129            connection
130                .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
131                .await?;
132        }
133
134        if let Some(pos) = self.channels.iter().position(|s| s == channel) {
135            self.channels.remove(pos);
136        }
137
138        Ok(())
139    }
140
141    /// Stops listening for notifications on all channels.
142    pub async fn unlisten_all(&mut self) -> Result<(), Error> {
143        // use RAW connection and do NOT re-connect automatically, since this is not required for
144        // UNLISTEN (we've disconnected anyways)
145        if let Some(connection) = self.connection.as_mut() {
146            connection.execute("UNLISTEN *").await?;
147        }
148
149        self.channels.clear();
150
151        Ok(())
152    }
153
154    #[inline]
155    async fn connect_if_needed(&mut self) -> Result<(), Error> {
156        if self.connection.is_none() {
157            let mut connection = self.pool.acquire().await?;
158            connection.stream.notifications = self.buffer_tx.take();
159
160            connection
161                .execute(&*build_listen_all_query(&self.channels))
162                .await?;
163
164            self.connection = Some(connection);
165        }
166
167        Ok(())
168    }
169
170    #[inline]
171    async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
172        // Ensure we have an active connection to work with.
173        self.connect_if_needed().await?;
174
175        Ok(self.connection.as_mut().unwrap())
176    }
177
178    /// Receives the next notification available from any of the subscribed channels.
179    ///
180    /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
181    /// call to `recv()`, and should be entirely transparent (as long as it was just an
182    /// intermittent network failure or long-lived connection reaper).
183    ///
184    /// As notifications are transient, any received while the connection was lost, will not
185    /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
186    /// do something before, please see [`try_recv`](Self::try_recv).
187    ///
188    /// # Example
189    ///
190    /// ```rust,no_run
191    /// # use sqlx_core::postgres::PgListener;
192    /// # use sqlx_core::error::Error;
193    /// #
194    /// # #[cfg(feature = "_rt")]
195    /// # sqlx::__rt::test_block_on(async move {
196    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
197    /// loop {
198    ///     // ask for next notification, re-connecting (transparently) if needed
199    ///     let notification = listener.recv().await?;
200    ///
201    ///     // handle notification, do something interesting
202    /// }
203    /// # Result::<(), Error>::Ok(())
204    /// # }).unwrap();
205    /// ```
206    pub async fn recv(&mut self) -> Result<PgNotification, Error> {
207        loop {
208            if let Some(notification) = self.try_recv().await? {
209                return Ok(notification);
210            }
211        }
212    }
213
214    /// Receives the next notification available from any of the subscribed channels.
215    ///
216    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
217    /// reconnected on the next call to `try_recv()`.
218    ///
219    /// # Example
220    ///
221    /// ```rust,no_run
222    /// # use sqlx_core::postgres::PgListener;
223    /// # use sqlx_core::error::Error;
224    /// #
225    /// # #[cfg(feature = "_rt")]
226    /// # sqlx::__rt::test_block_on(async move {
227    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
228    /// loop {
229    ///     // start handling notifications, connecting if needed
230    ///     while let Some(notification) = listener.try_recv().await? {
231    ///         // handle notification
232    ///     }
233    ///
234    ///     // connection lost, do something interesting
235    /// }
236    /// # Result::<(), Error>::Ok(())
237    /// # }).unwrap();
238    /// ```
239    pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
240        // Flush the buffer first, if anything
241        // This would only fill up if this listener is used as a connection
242        if let Ok(Some(notification)) = self.buffer_rx.try_next() {
243            return Ok(Some(PgNotification(notification)));
244        }
245
246        // Fetch our `CloseEvent` listener, if applicable.
247        let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
248
249        loop {
250            let next_message = self.connection().await?.stream.recv_unchecked();
251
252            let res = if let Some(ref mut close_event) = close_event {
253                // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
254                // before `next_message` returns, or if the pool was already closed
255                close_event.do_until(next_message).await?
256            } else {
257                next_message.await
258            };
259
260            let message = match res {
261                Ok(message) => message,
262
263                // The connection is dead, ensure that it is dropped,
264                // update self state, and loop to try again.
265                Err(Error::Io(err))
266                    if (err.kind() == io::ErrorKind::ConnectionAborted
267                        || err.kind() == io::ErrorKind::UnexpectedEof) =>
268                {
269                    self.buffer_tx = self.connection().await?.stream.notifications.take();
270                    self.connection = None;
271
272                    // lost connection
273                    return Ok(None);
274                }
275
276                // Forward other errors
277                Err(error) => {
278                    return Err(error);
279                }
280            };
281
282            match message.format {
283                // We've received an async notification, return it.
284                MessageFormat::NotificationResponse => {
285                    return Ok(Some(PgNotification(message.decode()?)));
286                }
287
288                // Mark the connection as ready for another query
289                MessageFormat::ReadyForQuery => {
290                    self.connection().await?.pending_ready_for_query_count -= 1;
291                }
292
293                // Ignore unexpected messages
294                _ => {}
295            }
296        }
297    }
298
299    /// Consume this listener, returning a `Stream` of notifications.
300    ///
301    /// The backing connection will be automatically reconnected should it be lost.
302    ///
303    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
304    ///
305    pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
306        Box::pin(try_stream! {
307            loop {
308                r#yield!(self.recv().await?);
309            }
310        })
311    }
312}
313
314impl Drop for PgListener {
315    fn drop(&mut self) {
316        if let Some(mut conn) = self.connection.take() {
317            let fut = async move {
318                let _ = conn.execute("UNLISTEN *").await;
319
320                // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
321                // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
322                // https://github.com/launchbadge/sqlx/issues/1389
323                conn.return_to_pool().await;
324            };
325
326            // Unregister any listeners before returning the connection to the pool.
327            crate::rt::spawn(fut);
328        }
329    }
330}
331
332impl<'c> Executor<'c> for &'c mut PgListener {
333    type Database = Postgres;
334
335    fn fetch_many<'e, 'q: 'e, E: 'q>(
336        self,
337        query: E,
338    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
339    where
340        'c: 'e,
341        E: Execute<'q, Self::Database>,
342    {
343        futures_util::stream::once(async move {
344            // need some basic type annotation to help the compiler a bit
345            let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
346            res
347        })
348        .try_flatten()
349        .boxed()
350    }
351
352    fn fetch_optional<'e, 'q: 'e, E: 'q>(
353        self,
354        query: E,
355    ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
356    where
357        'c: 'e,
358        E: Execute<'q, Self::Database>,
359    {
360        async move { self.connection().await?.fetch_optional(query).await }.boxed()
361    }
362
363    fn prepare_with<'e, 'q: 'e>(
364        self,
365        query: &'q str,
366        parameters: &'e [PgTypeInfo],
367    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
368    where
369        'c: 'e,
370    {
371        async move {
372            self.connection()
373                .await?
374                .prepare_with(query, parameters)
375                .await
376        }
377        .boxed()
378    }
379
380    #[doc(hidden)]
381    fn describe<'e, 'q: 'e>(
382        self,
383        query: &'q str,
384    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
385    where
386        'c: 'e,
387    {
388        async move { self.connection().await?.describe(query).await }.boxed()
389    }
390}
391
392impl PgNotification {
393    /// The process ID of the notifying backend process.
394    #[inline]
395    pub fn process_id(&self) -> u32 {
396        self.0.process_id
397    }
398
399    /// The channel that the notify has been raised on. This can be thought
400    /// of as the message topic.
401    #[inline]
402    pub fn channel(&self) -> &str {
403        from_utf8(&self.0.channel).unwrap()
404    }
405
406    /// The payload of the notification. An empty payload is received as an
407    /// empty string.
408    #[inline]
409    pub fn payload(&self) -> &str {
410        from_utf8(&self.0.payload).unwrap()
411    }
412}
413
414impl Debug for PgListener {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        f.debug_struct("PgListener").finish()
417    }
418}
419
420impl Debug for PgNotification {
421    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422        f.debug_struct("PgNotification")
423            .field("process_id", &self.process_id())
424            .field("channel", &self.channel())
425            .field("payload", &self.payload())
426            .finish()
427    }
428}
429
/// Escapes `name` for use inside a double-quoted Postgres identifier.
///
/// Everything from the first NUL byte onward is discarded, and each `"` is doubled.
/// The caller supplies the surrounding quotes.
fn ident(name: &str) -> String {
    // An identifier cannot carry a NUL byte, so keep only what precedes the first one.
    let visible = match name.find('\0') {
        Some(end) => &name[..end],
        None => name,
    };

    // Embedded double quotes are escaped by doubling them.
    visible.replace('"', "\"\"")
}
440
441fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
442    channels.into_iter().fold(String::new(), |mut acc, chan| {
443        acc.push_str(r#"LISTEN ""#);
444        acc.push_str(&ident(chan.as_ref()));
445        acc.push_str(r#"";"#);
446        acc
447    })
448}
449
// Unit tests for `build_listen_all_query`, including the escaping done by `ident`.

#[test]
fn test_build_listen_all_query_with_single_channel() {
    let output = build_listen_all_query(&["test"]);
    assert_eq!(output.as_str(), r#"LISTEN "test";"#);
}

#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let output = build_listen_all_query(&["channel.0", "channel.1"]);
    assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}

#[test]
fn test_build_listen_all_query_with_no_channels() {
    // No channels means no statements at all.
    let output = build_listen_all_query(Vec::<String>::new());
    assert_eq!(output.as_str(), "");
}

#[test]
fn test_build_listen_all_query_escapes_quotes_and_truncates_at_nul() {
    // Embedded quotes are doubled; anything after a NUL byte is dropped.
    let output = build_listen_all_query(&["a\"b", "c\0d"]);
    assert_eq!(output.as_str(), r#"LISTEN "a""b";LISTEN "c";"#);
}