sqlx_postgres/connection/stream.rs

1use std::collections::BTreeMap;
2use std::ops::{ControlFlow, Deref, DerefMut};
3use std::str::FromStr;
4
5use futures_channel::mpsc::UnboundedSender;
6use futures_util::SinkExt;
7use log::Level;
8use sqlx_core::bytes::Buf;
9
10use crate::connection::tls::MaybeUpgradeTls;
11use crate::error::Error;
12use crate::message::{
13    BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
14    ParameterStatus, ReceivedMessage,
15};
16use crate::net::{self, BufferedSocket, Socket};
17use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
18
19// the stream is a separate type from the connection to uphold the invariant where an instantiated
20// [PgConnection] is a **valid** connection to postgres
21
22// when a new connection is asked for, we work directly on the [PgStream] type until the
23// connection is fully established
24
25// in other words, `self` in any PgConnection method is a live connection to postgres that
26// is fully prepared to receive queries
27
/// Low-level stream to a Postgres server: a buffered socket plus the session
/// state that [`PgStream::recv`] collects from asynchronous backend messages.
pub struct PgStream {
    // A trait object is okay here as the buffering amortizes the overhead of both the dynamic
    // function call as well as the syscall.
    inner: BufferedSocket<Box<dyn Socket>>,

    /// Sink for `NotificationResponse` messages (produced by `NOTIFY` / `pg_notify`)
    /// that arrive while this stream is busy with something else.
    ///
    /// This is set when creating a PgListener and only written to if that listener is
    /// re-used for query execution in-between receiving messages. While it is `None`,
    /// notifications are returned from `recv` like any other message.
    pub(crate) notifications: Option<UnboundedSender<Notification>>,

    /// Latest value of each `ParameterStatus` reported by the server, keyed by
    /// parameter name. `server_version` is not stored here (see `server_version_num`).
    pub(crate) parameter_statuses: BTreeMap<String, String>,

    /// The reported `server_version` in libpq-style numeric form (e.g. `90601` for
    /// `9.6.1`); `None` until it is reported, or if it could not be parsed.
    pub(crate) server_version_num: Option<u32>,
}
42
43impl PgStream {
44    pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
45        let socket_result = match options.fetch_socket() {
46            Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
47            None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
48        };
49
50        let socket = socket_result?;
51
52        Ok(Self {
53            inner: BufferedSocket::new(socket),
54            notifications: None,
55            parameter_statuses: BTreeMap::default(),
56            server_version_num: None,
57        })
58    }
59
60    #[inline(always)]
61    pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
62        self.write(EncodeMessage(message))
63    }
64
65    pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
66    where
67        T: FrontendMessage,
68    {
69        self.write_msg(message)?;
70        self.flush().await?;
71        Ok(())
72    }
73
74    // Expect a specific type and format
75    pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
76        self.recv().await?.decode()
77    }
78
79    pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
80        // NOTE: to not break everything, this should be cancel-safe;
81        // DO NOT modify `buf` unless a full message has been read
82        self.inner
83            .try_read(|buf| {
84                // all packets in postgres start with a 5-byte header
85                // this header contains the message type and the total length of the message
86                let Some(mut header) = buf.get(..5) else {
87                    return Ok(ControlFlow::Continue(5));
88                };
89
90                let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
91
92                let message_len = header.get_u32() as usize;
93
94                let expected_len = message_len
95                    .checked_add(1)
96                    // this shouldn't really happen but is mostly a sanity check
97                    .ok_or_else(|| {
98                        err_protocol!("message_len + 1 overflows usize: {message_len}")
99                    })?;
100
101                if buf.len() < expected_len {
102                    return Ok(ControlFlow::Continue(expected_len));
103                }
104
105                // `buf` SHOULD NOT be modified ABOVE this line
106
107                // pop off the format code since it's not counted in `message_len`
108                buf.advance(1);
109
110                // consume the message, including the length prefix
111                let mut contents = buf.split_to(message_len).freeze();
112
113                // cut off the length prefix
114                contents.advance(4);
115
116                Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
117            })
118            .await
119    }
120
121    // Get the next message from the server
122    // May wait for more data from the server
123    pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
124        loop {
125            let message = self.recv_unchecked().await?;
126
127            match message.format {
128                BackendMessageFormat::ErrorResponse => {
129                    // An error returned from the database server.
130                    return Err(message.decode::<PgDatabaseError>()?.into());
131                }
132
133                BackendMessageFormat::NotificationResponse => {
134                    if let Some(buffer) = &mut self.notifications {
135                        let notification: Notification = message.decode()?;
136                        let _ = buffer.send(notification).await;
137
138                        continue;
139                    }
140                }
141
142                BackendMessageFormat::ParameterStatus => {
143                    // informs the frontend about the current (initial)
144                    // setting of backend parameters
145
146                    let ParameterStatus { name, value } = message.decode()?;
147                    // TODO: handle `client_encoding`, `DateStyle` change
148
149                    match name.as_str() {
150                        "server_version" => {
151                            self.server_version_num = parse_server_version(&value);
152                        }
153                        _ => {
154                            self.parameter_statuses.insert(name, value);
155                        }
156                    }
157
158                    continue;
159                }
160
161                BackendMessageFormat::NoticeResponse => {
162                    // do we need this to be more configurable?
163                    // if you are reading this comment and think so, open an issue
164
165                    let notice: Notice = message.decode()?;
166
167                    let (log_level, tracing_level) = match notice.severity() {
168                        PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
169                            (Level::Error, tracing::Level::ERROR)
170                        }
171                        PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
172                        PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
173                        PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
174                        PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
175                    };
176
177                    let log_is_enabled = log::log_enabled!(
178                        target: "sqlx::postgres::notice",
179                        log_level
180                    ) || sqlx_core::private_tracing_dynamic_enabled!(
181                        target: "sqlx::postgres::notice",
182                        tracing_level
183                    );
184                    if log_is_enabled {
185                        sqlx_core::private_tracing_dynamic_event!(
186                            target: "sqlx::postgres::notice",
187                            tracing_level,
188                            message = notice.message()
189                        );
190                    }
191
192                    continue;
193                }
194
195                _ => {}
196            }
197
198            return Ok(message);
199        }
200    }
201}
202
203impl Deref for PgStream {
204    type Target = BufferedSocket<Box<dyn Socket>>;
205
206    #[inline]
207    fn deref(&self) -> &Self::Target {
208        &self.inner
209    }
210}
211
212impl DerefMut for PgStream {
213    #[inline]
214    fn deref_mut(&mut self) -> &mut Self::Target {
215        &mut self.inner
216    }
217}
218
/// Converts a `server_version` parameter value into libpq's numeric form,
/// e.g. `"9.6.1"` -> `90601`, `"10.1"` -> `100001`, `"10devel"` -> `100000`.
///
/// Like libpq's `sscanf(value, "%d.%d.%d", ...)`, at most three dot-separated
/// numeric components are read from the start of the string. Parsing stops at
/// the first component that is not a run of ASCII digits, so suffixes such as
/// `devel`, `beta1` or ` (Debian 16.1-1)` are ignored.
///
/// Returns `None` when no leading number can be read, or when the resulting
/// value would not fit in a `u32`.
///
/// Reference:
/// <https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065>
fn parse_server_version(s: &str) -> Option<u32> {
    // major, minor, revision — only the first `len` entries are meaningful
    let mut parts = [0u32; 3];
    let mut len = 0;
    let mut rest = s;

    for slot in &mut parts {
        let digits = rest.bytes().take_while(u8::is_ascii_digit).count();
        // An empty or out-of-range component ends parsing.
        let Ok(num) = u32::from_str(&rest[..digits]) else {
            break;
        };
        *slot = num;
        len += 1;

        // The next component is only read if this one is directly followed by '.'.
        match rest[digits..].strip_prefix('.') {
            Some(tail) => rest = tail,
            None => break,
        }
    }

    // Checked arithmetic: an absurdly large component must yield `None` rather
    // than panic (debug builds) or silently wrap around (release builds).
    match parts[..len] {
        // old style, e.g. 9.6.1
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev),
        // new style, e.g. 10.1
        [major, minor] if major >= 10 => major.checked_mul(100 * 100)?.checked_add(minor),
        // old style without minor version, e.g. 9.6devel
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100),
        // new style without minor version, e.g. 10devel
        [major] => major.checked_mul(100 * 100),
        _ => None,
    }
}
263
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        let cases: &[(&str, Option<u32>)] = &[
            // old style
            ("9.6.1", Some(90601)),
            // new style
            ("10.1", Some(100001)),
            // old style without minor version
            ("9.6devel", Some(90600)),
            // new style without minor version
            ("10devel", Some(100000)),
            ("13devel87", Some(130000)),
            ("15beta1", Some(150000)),
            // trailing description after the version number
            ("16.1 (Debian 16.1-1.pgdg120+1)", Some(160001)),
            // unknown
            ("unknown", None),
            ("", None),
        ];

        for &(input, expected) in cases {
            assert_eq!(parse_server_version(input), expected, "input: {input:?}");
        }
    }
}