sqlx_postgres/listener.rs
1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::transaction::Transaction;
11use sqlx_core::Either;
12use tracing::Instrument;
13
14use crate::describe::Describe;
15use crate::error::Error;
16use crate::executor::{Execute, Executor};
17use crate::message::{BackendMessageFormat, Notification};
18use crate::pool::PoolOptions;
19use crate::pool::{Pool, PoolConnection};
20use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
21
22/// A stream of asynchronous notifications from Postgres.
23///
24/// This listener will auto-reconnect. If the active
25/// connection being used ever dies, this listener will detect that event, create a
26/// new connection, will re-subscribe to all of the originally specified channels, and will resume
27/// operations as normal.
28pub struct PgListener {
29 pool: Pool<Postgres>,
30 connection: Option<PoolConnection<Postgres>>,
31 buffer_rx: mpsc::UnboundedReceiver<Notification>,
32 buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
33 channels: Vec<String>,
34 ignore_close_event: bool,
35 eager_reconnect: bool,
36}
37
38/// An asynchronous notification from Postgres.
39pub struct PgNotification(Notification);
40
41impl PgListener {
42 pub async fn connect(url: &str) -> Result<Self, Error> {
43 // Create a pool of 1 without timeouts (as they don't apply here)
44 // We only use the pool to handle re-connections
45 let pool = PoolOptions::<Postgres>::new()
46 .max_connections(1)
47 .max_lifetime(None)
48 .idle_timeout(None)
49 .connect(url)
50 .await?;
51
52 let mut this = Self::connect_with(&pool).await?;
53 // We don't need to handle close events
54 this.ignore_close_event = true;
55
56 Ok(this)
57 }
58
59 pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
60 // Pull out an initial connection
61 let mut connection = pool.acquire().await?;
62
63 // Setup a notification buffer
64 let (sender, receiver) = mpsc::unbounded();
65 connection.inner.stream.notifications = Some(sender);
66
67 Ok(Self {
68 pool: pool.clone(),
69 connection: Some(connection),
70 buffer_rx: receiver,
71 buffer_tx: None,
72 channels: Vec::new(),
73 ignore_close_event: false,
74 eager_reconnect: true,
75 })
76 }
77
78 /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
79 ///
80 /// By default, when [`Pool::close()`] is called on the pool this listener is using
81 /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
82 /// cancelled and `Err(PoolClosed)` is returned.
83 ///
84 /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
85 /// including the one being used by this listener.
86 ///
87 /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
88 /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
89 /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
90 /// on the attempt to acquire a new connection anyway.
91 ///
92 /// However, if you want `PgListener` to ignore the close event and continue waiting for a
93 /// message as long as it can, set this to `true`.
94 ///
95 /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
96 /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
97 pub fn ignore_pool_close_event(&mut self, val: bool) {
98 self.ignore_close_event = val;
99 }
100
101 /// Set whether a lost connection in `try_recv()` should be re-established before it returns
102 /// `Ok(None)`, or on the next call to `try_recv()`.
103 ///
104 /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
105 ///
106 /// If this is set to `false` then notifications will continue to be lost until the next call
107 /// to `try_recv()`. If your recovery logic uses a different database connection then
108 /// notifications that occur after it completes may be lost without any way to tell that they
109 /// have been.
110 pub fn eager_reconnect(&mut self, val: bool) {
111 self.eager_reconnect = val;
112 }
113
114 /// Starts listening for notifications on a channel.
115 /// The channel name is quoted here to ensure case sensitivity.
116 pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
117 self.connection()
118 .await?
119 .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
120 .await?;
121
122 self.channels.push(channel.to_owned());
123
124 Ok(())
125 }
126
127 /// Starts listening for notifications on all channels.
128 pub async fn listen_all(
129 &mut self,
130 channels: impl IntoIterator<Item = &str>,
131 ) -> Result<(), Error> {
132 let beg = self.channels.len();
133 self.channels.extend(channels.into_iter().map(|s| s.into()));
134
135 let query = build_listen_all_query(&self.channels[beg..]);
136 self.connection().await?.execute(&*query).await?;
137
138 Ok(())
139 }
140
141 /// Stops listening for notifications on a channel.
142 /// The channel name is quoted here to ensure case sensitivity.
143 pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
144 // use RAW connection and do NOT re-connect automatically, since this is not required for
145 // UNLISTEN (we've disconnected anyways)
146 if let Some(connection) = self.connection.as_mut() {
147 connection
148 .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
149 .await?;
150 }
151
152 if let Some(pos) = self.channels.iter().position(|s| s == channel) {
153 self.channels.remove(pos);
154 }
155
156 Ok(())
157 }
158
159 /// Stops listening for notifications on all channels.
160 pub async fn unlisten_all(&mut self) -> Result<(), Error> {
161 // use RAW connection and do NOT re-connect automatically, since this is not required for
162 // UNLISTEN (we've disconnected anyways)
163 if let Some(connection) = self.connection.as_mut() {
164 connection.execute("UNLISTEN *").await?;
165 }
166
167 self.channels.clear();
168
169 Ok(())
170 }
171
172 #[inline]
173 async fn connect_if_needed(&mut self) -> Result<(), Error> {
174 if self.connection.is_none() {
175 let mut connection = self.pool.acquire().await?;
176 connection.inner.stream.notifications = self.buffer_tx.take();
177
178 connection
179 .execute(&*build_listen_all_query(&self.channels))
180 .await?;
181
182 self.connection = Some(connection);
183 }
184
185 Ok(())
186 }
187
188 #[inline]
189 async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
190 // Ensure we have an active connection to work with.
191 self.connect_if_needed().await?;
192
193 Ok(self.connection.as_mut().unwrap())
194 }
195
196 /// Receives the next notification available from any of the subscribed channels.
197 ///
198 /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
199 /// call to `recv()`, and should be entirely transparent (as long as it was just an
200 /// intermittent network failure or long-lived connection reaper).
201 ///
202 /// As notifications are transient, any received while the connection was lost, will not
203 /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
204 /// do something before, please see [`try_recv`](Self::try_recv).
205 ///
206 /// # Example
207 ///
208 /// ```rust,no_run
209 /// # use sqlx::postgres::PgListener;
210 /// #
211 /// # sqlx::__rt::test_block_on(async move {
212 /// let mut listener = PgListener::connect("postgres:// ...").await?;
213 /// loop {
214 /// // ask for next notification, re-connecting (transparently) if needed
215 /// let notification = listener.recv().await?;
216 ///
217 /// // handle notification, do something interesting
218 /// }
219 /// # Result::<(), sqlx::Error>::Ok(())
220 /// # }).unwrap();
221 /// ```
222 pub async fn recv(&mut self) -> Result<PgNotification, Error> {
223 loop {
224 if let Some(notification) = self.try_recv().await? {
225 return Ok(notification);
226 }
227 }
228 }
229
230 /// Receives the next notification available from any of the subscribed channels.
231 ///
232 /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
233 /// reconnected either immediately, or on the next call to `try_recv()`, depending on
234 /// the value of [`eager_reconnect`].
235 ///
236 /// # Example
237 ///
238 /// ```rust,no_run
239 /// # use sqlx::postgres::PgListener;
240 /// #
241 /// # sqlx::__rt::test_block_on(async move {
242 /// # let mut listener = PgListener::connect("postgres:// ...").await?;
243 /// loop {
244 /// // start handling notifications, connecting if needed
245 /// while let Some(notification) = listener.try_recv().await? {
246 /// // handle notification
247 /// }
248 ///
249 /// // connection lost, do something interesting
250 /// }
251 /// # Result::<(), sqlx::Error>::Ok(())
252 /// # }).unwrap();
253 /// ```
254 ///
255 /// [`eager_reconnect`]: PgListener::eager_reconnect
256 pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
257 // Flush the buffer first, if anything
258 // This would only fill up if this listener is used as a connection
259 if let Some(notification) = self.next_buffered() {
260 return Ok(Some(notification));
261 }
262
263 // Fetch our `CloseEvent` listener, if applicable.
264 let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
265
266 loop {
267 let next_message = self.connection().await?.inner.stream.recv_unchecked();
268
269 let res = if let Some(ref mut close_event) = close_event {
270 // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
271 // before `next_message` returns, or if the pool was already closed
272 close_event.do_until(next_message).await?
273 } else {
274 next_message.await
275 };
276
277 let message = match res {
278 Ok(message) => message,
279
280 // The connection is dead, ensure that it is dropped,
281 // update self state, and loop to try again.
282 Err(Error::Io(err))
283 if matches!(
284 err.kind(),
285 io::ErrorKind::ConnectionAborted |
286 io::ErrorKind::UnexpectedEof |
287 // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
288 io::ErrorKind::TimedOut |
289 io::ErrorKind::BrokenPipe
290 ) =>
291 {
292 if let Some(mut conn) = self.connection.take() {
293 self.buffer_tx = conn.inner.stream.notifications.take();
294 // Close the connection in a background task, so we can continue.
295 conn.close_on_drop();
296 }
297
298 if self.eager_reconnect {
299 self.connect_if_needed().await?;
300 }
301
302 // lost connection
303 return Ok(None);
304 }
305
306 // Forward other errors
307 Err(error) => {
308 return Err(error);
309 }
310 };
311
312 match message.format {
313 // We've received an async notification, return it.
314 BackendMessageFormat::NotificationResponse => {
315 return Ok(Some(PgNotification(message.decode()?)));
316 }
317
318 // Mark the connection as ready for another query
319 BackendMessageFormat::ReadyForQuery => {
320 self.connection().await?.inner.pending_ready_for_query_count -= 1;
321 }
322
323 // Ignore unexpected messages
324 _ => {}
325 }
326 }
327 }
328
329 /// Receives the next notification that already exists in the connection buffer, if any.
330 ///
331 /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
332 ///
333 /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
334 pub fn next_buffered(&mut self) -> Option<PgNotification> {
335 if let Ok(Some(notification)) = self.buffer_rx.try_next() {
336 Some(PgNotification(notification))
337 } else {
338 None
339 }
340 }
341
342 /// Consume this listener, returning a `Stream` of notifications.
343 ///
344 /// The backing connection will be automatically reconnected should it be lost.
345 ///
346 /// This has the same potential drawbacks as [`recv`](PgListener::recv).
347 ///
348 pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
349 Box::pin(try_stream! {
350 loop {
351 r#yield!(self.recv().await?);
352 }
353 })
354 }
355}
356
357impl Drop for PgListener {
358 fn drop(&mut self) {
359 if let Some(mut conn) = self.connection.take() {
360 let fut = async move {
361 let _ = conn.execute("UNLISTEN *").await;
362
363 // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
364 // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
365 // https://github.com/launchbadge/sqlx/issues/1389
366 conn.return_to_pool().await;
367 };
368
369 // Unregister any listeners before returning the connection to the pool.
370 crate::rt::spawn(fut.in_current_span());
371 }
372 }
373}
374
375impl<'c> Acquire<'c> for &'c mut PgListener {
376 type Database = Postgres;
377 type Connection = &'c mut PgConnection;
378
379 fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
380 self.connection().boxed()
381 }
382
383 fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
384 self.connection().and_then(|c| c.begin()).boxed()
385 }
386}
387
388impl<'c> Executor<'c> for &'c mut PgListener {
389 type Database = Postgres;
390
391 fn fetch_many<'e, 'q, E>(
392 self,
393 query: E,
394 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
395 where
396 'c: 'e,
397 E: Execute<'q, Self::Database>,
398 'q: 'e,
399 E: 'q,
400 {
401 futures_util::stream::once(async move {
402 // need some basic type annotation to help the compiler a bit
403 let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
404 res
405 })
406 .try_flatten()
407 .boxed()
408 }
409
410 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
411 where
412 'c: 'e,
413 E: Execute<'q, Self::Database>,
414 'q: 'e,
415 E: 'q,
416 {
417 async move { self.connection().await?.fetch_optional(query).await }.boxed()
418 }
419
420 fn prepare_with<'e, 'q: 'e>(
421 self,
422 query: &'q str,
423 parameters: &'e [PgTypeInfo],
424 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
425 where
426 'c: 'e,
427 {
428 async move {
429 self.connection()
430 .await?
431 .prepare_with(query, parameters)
432 .await
433 }
434 .boxed()
435 }
436
437 #[doc(hidden)]
438 fn describe<'e, 'q: 'e>(
439 self,
440 query: &'q str,
441 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
442 where
443 'c: 'e,
444 {
445 async move { self.connection().await?.describe(query).await }.boxed()
446 }
447}
448
449impl PgNotification {
450 /// The process ID of the notifying backend process.
451 #[inline]
452 pub fn process_id(&self) -> u32 {
453 self.0.process_id
454 }
455
456 /// The channel that the notify has been raised on. This can be thought
457 /// of as the message topic.
458 #[inline]
459 pub fn channel(&self) -> &str {
460 from_utf8(&self.0.channel).unwrap()
461 }
462
463 /// The payload of the notification. An empty payload is received as an
464 /// empty string.
465 #[inline]
466 pub fn payload(&self) -> &str {
467 from_utf8(&self.0.payload).unwrap()
468 }
469}
470
471impl Debug for PgListener {
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 f.debug_struct("PgListener").finish()
474 }
475}
476
477impl Debug for PgNotification {
478 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
479 f.debug_struct("PgNotification")
480 .field("process_id", &self.process_id())
481 .field("channel", &self.channel())
482 .field("payload", &self.payload())
483 .finish()
484 }
485}
486
/// Prepares `name` for use inside a double-quoted SQL identifier.
///
/// The name is cut off at the first NUL byte, if there is one, and every double quote in
/// the remainder is doubled. The surrounding quotes are not added here.
fn ident(name: &str) -> String {
    // Keep only the part in front of the first NUL byte.
    let visible = match name.find('\0') {
        Some(nul_at) => &name[..nul_at],
        None => name,
    };

    // Inside a quoted identifier a double quote is escaped by doubling it.
    visible.replace('"', r#""""#)
}
497
498fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
499 channels.into_iter().fold(String::new(), |mut acc, chan| {
500 acc.push_str(r#"LISTEN ""#);
501 acc.push_str(&ident(chan.as_ref()));
502 acc.push_str(r#"";"#);
503 acc
504 })
505}
506
/// A single channel yields exactly one quoted `LISTEN` statement.
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let channels = ["test"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "test";"#);
}
512
/// Several channels yield one `LISTEN` statement each, concatenated in input order.
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let channels = ["channel.0", "channel.1"];
    let query = build_listen_all_query(&channels);

    assert_eq!(query, r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}