Skip to main content

sentinel_driver/notify/
mod.rs

1pub mod channel;
2
3use crate::connection::stream::PgConnection;
4use crate::error::{Error, Result};
5use crate::protocol::backend::BackendMessage;
6use crate::protocol::frontend;
7
/// A single `LISTEN`/`NOTIFY` event delivered by the PostgreSQL server.
///
/// Values of this type are built from a backend `NotificationResponse` message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Notification {
    /// Process ID of the server backend that issued the notification.
    pub process_id: i32,
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload text attached to the notification; empty when none was given.
    pub payload: String,
}
18
19/// Subscribe to a channel on the given connection.
20///
21/// Sends `LISTEN <channel>` and waits for confirmation.
22pub(crate) async fn listen(conn: &mut PgConnection, channel: &str) -> Result<()> {
23    // Validate channel name (prevent SQL injection)
24    validate_channel_name(channel)?;
25
26    let sql = format!("LISTEN {}", quote_identifier(channel));
27    frontend::query(conn.write_buf(), &sql);
28    conn.send().await?;
29
30    // Expect CommandComplete + ReadyForQuery
31    loop {
32        match conn.recv().await? {
33            BackendMessage::ReadyForQuery { .. } => return Ok(()),
34            BackendMessage::ErrorResponse { fields } => {
35                drain_until_ready(conn).await.ok();
36                return Err(Error::server(
37                    fields.severity,
38                    fields.code,
39                    fields.message,
40                    fields.detail,
41                    fields.hint,
42                    fields.position,
43                ));
44            }
45            _ => {}
46        }
47    }
48}
49
50/// Unsubscribe from a channel.
51pub(crate) async fn unlisten(conn: &mut PgConnection, channel: &str) -> Result<()> {
52    validate_channel_name(channel)?;
53
54    let sql = format!("UNLISTEN {}", quote_identifier(channel));
55    frontend::query(conn.write_buf(), &sql);
56    conn.send().await?;
57
58    loop {
59        match conn.recv().await? {
60            BackendMessage::ReadyForQuery { .. } => return Ok(()),
61            BackendMessage::ErrorResponse { fields } => {
62                drain_until_ready(conn).await.ok();
63                return Err(Error::server(
64                    fields.severity,
65                    fields.code,
66                    fields.message,
67                    fields.detail,
68                    fields.hint,
69                    fields.position,
70                ));
71            }
72            _ => {}
73        }
74    }
75}
76
77/// Unsubscribe from all channels.
78pub(crate) async fn unlisten_all(conn: &mut PgConnection) -> Result<()> {
79    frontend::query(conn.write_buf(), "UNLISTEN *");
80    conn.send().await?;
81
82    loop {
83        match conn.recv().await? {
84            BackendMessage::ReadyForQuery { .. } => return Ok(()),
85            BackendMessage::ErrorResponse { fields } => {
86                drain_until_ready(conn).await.ok();
87                return Err(Error::server(
88                    fields.severity,
89                    fields.code,
90                    fields.message,
91                    fields.detail,
92                    fields.hint,
93                    fields.position,
94                ));
95            }
96            _ => {}
97        }
98    }
99}
100
101/// Send a notification on a channel.
102pub(crate) async fn notify(conn: &mut PgConnection, channel: &str, payload: &str) -> Result<()> {
103    validate_channel_name(channel)?;
104
105    // Use pg_notify() function to safely pass the payload as a parameter
106    let sql = format!(
107        "SELECT pg_notify({}, {})",
108        quote_literal(channel),
109        quote_literal(payload)
110    );
111    frontend::query(conn.write_buf(), &sql);
112    conn.send().await?;
113
114    loop {
115        match conn.recv().await? {
116            BackendMessage::ReadyForQuery { .. } => return Ok(()),
117            BackendMessage::ErrorResponse { fields } => {
118                drain_until_ready(conn).await.ok();
119                return Err(Error::server(
120                    fields.severity,
121                    fields.code,
122                    fields.message,
123                    fields.detail,
124                    fields.hint,
125                    fields.position,
126                ));
127            }
128            _ => {}
129        }
130    }
131}
132
133/// Wait for the next notification on the connection.
134///
135/// This blocks until a NotificationResponse is received.
136/// Other messages (ParameterStatus, NoticeResponse) are silently consumed.
137pub(crate) async fn wait_for_notification(conn: &mut PgConnection) -> Result<Notification> {
138    loop {
139        match conn.recv().await? {
140            BackendMessage::NotificationResponse {
141                process_id,
142                channel,
143                payload,
144            } => {
145                return Ok(Notification {
146                    process_id,
147                    channel,
148                    payload,
149                });
150            }
151            BackendMessage::ErrorResponse { fields } => {
152                return Err(Error::server(
153                    fields.severity,
154                    fields.code,
155                    fields.message,
156                    fields.detail,
157                    fields.hint,
158                    fields.position,
159                ));
160            }
161            _ => {}
162        }
163    }
164}
165
166/// Validate that a channel name is safe to use in SQL.
167pub fn validate_channel_name(name: &str) -> Result<()> {
168    if name.is_empty() {
169        return Err(Error::Config("channel name cannot be empty".into()));
170    }
171    if name.len() > 63 {
172        return Err(Error::Config(
173            "channel name exceeds 63 character limit".into(),
174        ));
175    }
176    Ok(())
177}
178
/// Wrap `name` in double quotes so it can be spliced into SQL as an identifier.
///
/// Each embedded `"` is doubled, which is how a quote is written inside a
/// delimited identifier. All other characters are copied unchanged, so the
/// identifier keeps its exact case.
pub fn quote_identifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
183
/// Quote `val` as a PostgreSQL string literal for safe splicing into SQL.
///
/// Single quotes are always doubled (`it's` becomes `'it''s'`).
///
/// If `val` contains a backslash, the literal is emitted in escape-string form
/// (`E'...'`) with every backslash doubled as well. This mirrors PostgreSQL's
/// own `quote_literal()` and libpq's `PQescapeLiteral`. It keeps the result
/// correct even when the server runs with `standard_conforming_strings = off`.
/// In that mode a bare `\'` inside `'...'` would escape the quote and let
/// untrusted input break out of the literal.
///
/// NUL characters are copied through unchanged; PostgreSQL text values cannot
/// contain them, so callers should reject them beforehand.
pub fn quote_literal(val: &str) -> String {
    let needs_escape_syntax = val.contains('\\');

    let mut out = String::with_capacity(val.len() + 3);
    if needs_escape_syntax {
        out.push('E');
    }
    out.push('\'');
    for ch in val.chars() {
        match ch {
            '\'' => out.push_str("''"),
            '\\' => out.push_str("\\\\"),
            other => out.push(other),
        }
    }
    out.push('\'');
    out
}
188
189async fn drain_until_ready(conn: &mut PgConnection) -> Result<()> {
190    loop {
191        if let BackendMessage::ReadyForQuery { .. } = conn.recv().await? {
192            return Ok(());
193        }
194    }
195}