sqlx_build_trust_postgres/listener.rs
1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryStreamExt};
9use sqlx_core::Either;
10
11use crate::describe::Describe;
12use crate::error::Error;
13use crate::executor::{Execute, Executor};
14use crate::message::{MessageFormat, Notification};
15use crate::pool::PoolOptions;
16use crate::pool::{Pool, PoolConnection};
17use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
18
19/// A stream of asynchronous notifications from Postgres.
20///
21/// This listener will auto-reconnect. If the active
22/// connection being used ever dies, this listener will detect that event, create a
23/// new connection, will re-subscribe to all of the originally specified channels, and will resume
24/// operations as normal.
25pub struct PgListener {
26 pool: Pool<Postgres>,
27 connection: Option<PoolConnection<Postgres>>,
28 buffer_rx: mpsc::UnboundedReceiver<Notification>,
29 buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
30 channels: Vec<String>,
31 ignore_close_event: bool,
32}
33
34/// An asynchronous notification from Postgres.
35pub struct PgNotification(Notification);
36
37impl PgListener {
38 pub async fn connect(url: &str) -> Result<Self, Error> {
39 // Create a pool of 1 without timeouts (as they don't apply here)
40 // We only use the pool to handle re-connections
41 let pool = PoolOptions::<Postgres>::new()
42 .max_connections(1)
43 .max_lifetime(None)
44 .idle_timeout(None)
45 .connect(url)
46 .await?;
47
48 let mut this = Self::connect_with(&pool).await?;
49 // We don't need to handle close events
50 this.ignore_close_event = true;
51
52 Ok(this)
53 }
54
55 pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
56 // Pull out an initial connection
57 let mut connection = pool.acquire().await?;
58
59 // Setup a notification buffer
60 let (sender, receiver) = mpsc::unbounded();
61 connection.stream.notifications = Some(sender);
62
63 Ok(Self {
64 pool: pool.clone(),
65 connection: Some(connection),
66 buffer_rx: receiver,
67 buffer_tx: None,
68 channels: Vec::new(),
69 ignore_close_event: false,
70 })
71 }
72
73 /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
74 ///
75 /// By default, when [`Pool::close()`] is called on the pool this listener is using
76 /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
77 /// cancelled and `Err(PoolClosed)` is returned.
78 ///
79 /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
80 /// including the one being used by this listener.
81 ///
82 /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
83 /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
84 /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
85 /// on the attempt to acquire a new connection anyway.
86 ///
87 /// However, if you want `PgListener` to ignore the close event and continue waiting for a
88 /// message as long as it can, set this to `true`.
89 ///
90 /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
91 /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
92 pub fn ignore_pool_close_event(&mut self, val: bool) {
93 self.ignore_close_event = val;
94 }
95
96 /// Starts listening for notifications on a channel.
97 /// The channel name is quoted here to ensure case sensitivity.
98 pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
99 self.connection()
100 .await?
101 .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
102 .await?;
103
104 self.channels.push(channel.to_owned());
105
106 Ok(())
107 }
108
109 /// Starts listening for notifications on all channels.
110 pub async fn listen_all(
111 &mut self,
112 channels: impl IntoIterator<Item = &str>,
113 ) -> Result<(), Error> {
114 let beg = self.channels.len();
115 self.channels.extend(channels.into_iter().map(|s| s.into()));
116
117 let query = build_listen_all_query(&self.channels[beg..]);
118 self.connection().await?.execute(&*query).await?;
119
120 Ok(())
121 }
122
123 /// Stops listening for notifications on a channel.
124 /// The channel name is quoted here to ensure case sensitivity.
125 pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
126 // use RAW connection and do NOT re-connect automatically, since this is not required for
127 // UNLISTEN (we've disconnected anyways)
128 if let Some(connection) = self.connection.as_mut() {
129 connection
130 .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
131 .await?;
132 }
133
134 if let Some(pos) = self.channels.iter().position(|s| s == channel) {
135 self.channels.remove(pos);
136 }
137
138 Ok(())
139 }
140
141 /// Stops listening for notifications on all channels.
142 pub async fn unlisten_all(&mut self) -> Result<(), Error> {
143 // use RAW connection and do NOT re-connect automatically, since this is not required for
144 // UNLISTEN (we've disconnected anyways)
145 if let Some(connection) = self.connection.as_mut() {
146 connection.execute("UNLISTEN *").await?;
147 }
148
149 self.channels.clear();
150
151 Ok(())
152 }
153
154 #[inline]
155 async fn connect_if_needed(&mut self) -> Result<(), Error> {
156 if self.connection.is_none() {
157 let mut connection = self.pool.acquire().await?;
158 connection.stream.notifications = self.buffer_tx.take();
159
160 connection
161 .execute(&*build_listen_all_query(&self.channels))
162 .await?;
163
164 self.connection = Some(connection);
165 }
166
167 Ok(())
168 }
169
170 #[inline]
171 async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
172 // Ensure we have an active connection to work with.
173 self.connect_if_needed().await?;
174
175 Ok(self.connection.as_mut().unwrap())
176 }
177
178 /// Receives the next notification available from any of the subscribed channels.
179 ///
180 /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
181 /// call to `recv()`, and should be entirely transparent (as long as it was just an
182 /// intermittent network failure or long-lived connection reaper).
183 ///
184 /// As notifications are transient, any received while the connection was lost, will not
185 /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
186 /// do something before, please see [`try_recv`](Self::try_recv).
187 ///
188 /// # Example
189 ///
190 /// ```rust,no_run
191 /// # use sqlx_core::postgres::PgListener;
192 /// # use sqlx_core::error::Error;
193 /// #
194 /// # #[cfg(feature = "_rt")]
195 /// # sqlx::__rt::test_block_on(async move {
196 /// # let mut listener = PgListener::connect("postgres:// ...").await?;
197 /// loop {
198 /// // ask for next notification, re-connecting (transparently) if needed
199 /// let notification = listener.recv().await?;
200 ///
201 /// // handle notification, do something interesting
202 /// }
203 /// # Result::<(), Error>::Ok(())
204 /// # }).unwrap();
205 /// ```
206 pub async fn recv(&mut self) -> Result<PgNotification, Error> {
207 loop {
208 if let Some(notification) = self.try_recv().await? {
209 return Ok(notification);
210 }
211 }
212 }
213
214 /// Receives the next notification available from any of the subscribed channels.
215 ///
216 /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
217 /// reconnected on the next call to `try_recv()`.
218 ///
219 /// # Example
220 ///
221 /// ```rust,no_run
222 /// # use sqlx_core::postgres::PgListener;
223 /// # use sqlx_core::error::Error;
224 /// #
225 /// # #[cfg(feature = "_rt")]
226 /// # sqlx::__rt::test_block_on(async move {
227 /// # let mut listener = PgListener::connect("postgres:// ...").await?;
228 /// loop {
229 /// // start handling notifications, connecting if needed
230 /// while let Some(notification) = listener.try_recv().await? {
231 /// // handle notification
232 /// }
233 ///
234 /// // connection lost, do something interesting
235 /// }
236 /// # Result::<(), Error>::Ok(())
237 /// # }).unwrap();
238 /// ```
239 pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
240 // Flush the buffer first, if anything
241 // This would only fill up if this listener is used as a connection
242 if let Ok(Some(notification)) = self.buffer_rx.try_next() {
243 return Ok(Some(PgNotification(notification)));
244 }
245
246 // Fetch our `CloseEvent` listener, if applicable.
247 let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
248
249 loop {
250 let next_message = self.connection().await?.stream.recv_unchecked();
251
252 let res = if let Some(ref mut close_event) = close_event {
253 // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
254 // before `next_message` returns, or if the pool was already closed
255 close_event.do_until(next_message).await?
256 } else {
257 next_message.await
258 };
259
260 let message = match res {
261 Ok(message) => message,
262
263 // The connection is dead, ensure that it is dropped,
264 // update self state, and loop to try again.
265 Err(Error::Io(err))
266 if (err.kind() == io::ErrorKind::ConnectionAborted
267 || err.kind() == io::ErrorKind::UnexpectedEof) =>
268 {
269 self.buffer_tx = self.connection().await?.stream.notifications.take();
270 self.connection = None;
271
272 // lost connection
273 return Ok(None);
274 }
275
276 // Forward other errors
277 Err(error) => {
278 return Err(error);
279 }
280 };
281
282 match message.format {
283 // We've received an async notification, return it.
284 MessageFormat::NotificationResponse => {
285 return Ok(Some(PgNotification(message.decode()?)));
286 }
287
288 // Mark the connection as ready for another query
289 MessageFormat::ReadyForQuery => {
290 self.connection().await?.pending_ready_for_query_count -= 1;
291 }
292
293 // Ignore unexpected messages
294 _ => {}
295 }
296 }
297 }
298
299 /// Consume this listener, returning a `Stream` of notifications.
300 ///
301 /// The backing connection will be automatically reconnected should it be lost.
302 ///
303 /// This has the same potential drawbacks as [`recv`](PgListener::recv).
304 ///
305 pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
306 Box::pin(try_stream! {
307 loop {
308 r#yield!(self.recv().await?);
309 }
310 })
311 }
312}
313
314impl Drop for PgListener {
315 fn drop(&mut self) {
316 if let Some(mut conn) = self.connection.take() {
317 let fut = async move {
318 let _ = conn.execute("UNLISTEN *").await;
319
320 // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
321 // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
322 // https://github.com/launchbadge/sqlx/issues/1389
323 conn.return_to_pool().await;
324 };
325
326 // Unregister any listeners before returning the connection to the pool.
327 crate::rt::spawn(fut);
328 }
329 }
330}
331
332impl<'c> Executor<'c> for &'c mut PgListener {
333 type Database = Postgres;
334
335 fn fetch_many<'e, 'q: 'e, E: 'q>(
336 self,
337 query: E,
338 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
339 where
340 'c: 'e,
341 E: Execute<'q, Self::Database>,
342 {
343 futures_util::stream::once(async move {
344 // need some basic type annotation to help the compiler a bit
345 let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
346 res
347 })
348 .try_flatten()
349 .boxed()
350 }
351
352 fn fetch_optional<'e, 'q: 'e, E: 'q>(
353 self,
354 query: E,
355 ) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
356 where
357 'c: 'e,
358 E: Execute<'q, Self::Database>,
359 {
360 async move { self.connection().await?.fetch_optional(query).await }.boxed()
361 }
362
363 fn prepare_with<'e, 'q: 'e>(
364 self,
365 query: &'q str,
366 parameters: &'e [PgTypeInfo],
367 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
368 where
369 'c: 'e,
370 {
371 async move {
372 self.connection()
373 .await?
374 .prepare_with(query, parameters)
375 .await
376 }
377 .boxed()
378 }
379
380 #[doc(hidden)]
381 fn describe<'e, 'q: 'e>(
382 self,
383 query: &'q str,
384 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
385 where
386 'c: 'e,
387 {
388 async move { self.connection().await?.describe(query).await }.boxed()
389 }
390}
391
392impl PgNotification {
393 /// The process ID of the notifying backend process.
394 #[inline]
395 pub fn process_id(&self) -> u32 {
396 self.0.process_id
397 }
398
399 /// The channel that the notify has been raised on. This can be thought
400 /// of as the message topic.
401 #[inline]
402 pub fn channel(&self) -> &str {
403 from_utf8(&self.0.channel).unwrap()
404 }
405
406 /// The payload of the notification. An empty payload is received as an
407 /// empty string.
408 #[inline]
409 pub fn payload(&self) -> &str {
410 from_utf8(&self.0.payload).unwrap()
411 }
412}
413
414impl Debug for PgListener {
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416 f.debug_struct("PgListener").finish()
417 }
418}
419
420impl Debug for PgNotification {
421 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
422 f.debug_struct("PgNotification")
423 .field("process_id", &self.process_id())
424 .field("channel", &self.channel())
425 .field("payload", &self.payload())
426 .finish()
427 }
428}
429
/// Escapes `name` for use inside a double-quoted SQL identifier.
///
/// Everything from the first NUL byte onward is discarded, and each `"` is doubled.
fn ident(name: &str) -> String {
    // `split` always yields at least one piece, so this is the text before the first NUL.
    let before_nul = name.split('\0').next().unwrap_or("");
    before_nul.replace('"', "\"\"")
}
440
441fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
442 channels.into_iter().fold(String::new(), |mut acc, chan| {
443 acc.push_str(r#"LISTEN ""#);
444 acc.push_str(&ident(chan.as_ref()));
445 acc.push_str(r#"";"#);
446 acc
447 })
448}
449
/// A single channel produces exactly one quoted `LISTEN` statement.
#[test]
fn test_build_listen_all_query_with_single_channel() {
    assert_eq!(build_listen_all_query(&["test"]), r#"LISTEN "test";"#);
}
455
/// Multiple channels produce one quoted `LISTEN` statement each, in input order.
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let output = build_listen_all_query(&["channel.0", "channel.1"]);
    assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}

/// Channel names are emitted as quoted identifiers: an embedded `"` must be doubled and
/// anything after a NUL byte dropped, so a crafted name cannot end the identifier early.
#[test]
fn test_build_listen_all_query_escapes_channel_names() {
    let output = build_listen_all_query(&[r#"evil"; DROP TABLE users; --"#, "nul\0tail"]);
    assert_eq!(
        output.as_str(),
        r#"LISTEN "evil""; DROP TABLE users; --";LISTEN "nul";"#
    );
}

/// No channels means no statements at all.
#[test]
fn test_build_listen_all_query_with_no_channels() {
    let output = build_listen_all_query(std::iter::empty::<&str>());
    assert!(output.is_empty());
}