sqlx_postgres/connection/
stream.rs1use std::collections::BTreeMap;
2use std::ops::{ControlFlow, Deref, DerefMut};
3use std::str::FromStr;
4
5use futures_channel::mpsc::UnboundedSender;
6use futures_util::SinkExt;
7use log::Level;
8use sqlx_core::bytes::Buf;
9
10use crate::connection::tls::MaybeUpgradeTls;
11use crate::error::Error;
12use crate::message::{
13 BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
14 ParameterStatus, ReceivedMessage,
15};
16use crate::net::{self, BufferedSocket, Socket};
17use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
18
19pub struct PgStream {
29 inner: BufferedSocket<Box<dyn Socket>>,
32
33 pub(crate) notifications: Option<UnboundedSender<Notification>>,
37
38 pub(crate) parameter_statuses: BTreeMap<String, String>,
39
40 pub(crate) server_version_num: Option<u32>,
41}
42
43impl PgStream {
44 pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
45 let socket_result = match options.fetch_socket() {
46 Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
47 None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
48 };
49
50 let socket = socket_result?;
51
52 Ok(Self {
53 inner: BufferedSocket::new(socket),
54 notifications: None,
55 parameter_statuses: BTreeMap::default(),
56 server_version_num: None,
57 })
58 }
59
60 #[inline(always)]
61 pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
62 self.write(EncodeMessage(message))
63 }
64
65 pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
66 where
67 T: FrontendMessage,
68 {
69 self.write_msg(message)?;
70 self.flush().await?;
71 Ok(())
72 }
73
74 pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
76 self.recv().await?.decode()
77 }
78
79 pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
80 self.inner
83 .try_read(|buf| {
84 let Some(mut header) = buf.get(..5) else {
87 return Ok(ControlFlow::Continue(5));
88 };
89
90 let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
91
92 let message_len = header.get_u32() as usize;
93
94 let expected_len = message_len
95 .checked_add(1)
96 .ok_or_else(|| {
98 err_protocol!("message_len + 1 overflows usize: {message_len}")
99 })?;
100
101 if buf.len() < expected_len {
102 return Ok(ControlFlow::Continue(expected_len));
103 }
104
105 buf.advance(1);
109
110 let mut contents = buf.split_to(message_len).freeze();
112
113 contents.advance(4);
115
116 Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
117 })
118 .await
119 }
120
121 pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
124 loop {
125 let message = self.recv_unchecked().await?;
126
127 match message.format {
128 BackendMessageFormat::ErrorResponse => {
129 return Err(message.decode::<PgDatabaseError>()?.into());
131 }
132
133 BackendMessageFormat::NotificationResponse => {
134 if let Some(buffer) = &mut self.notifications {
135 let notification: Notification = message.decode()?;
136 let _ = buffer.send(notification).await;
137
138 continue;
139 }
140 }
141
142 BackendMessageFormat::ParameterStatus => {
143 let ParameterStatus { name, value } = message.decode()?;
147 match name.as_str() {
150 "server_version" => {
151 self.server_version_num = parse_server_version(&value);
152 }
153 _ => {
154 self.parameter_statuses.insert(name, value);
155 }
156 }
157
158 continue;
159 }
160
161 BackendMessageFormat::NoticeResponse => {
162 let notice: Notice = message.decode()?;
166
167 let (log_level, tracing_level) = match notice.severity() {
168 PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
169 (Level::Error, tracing::Level::ERROR)
170 }
171 PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
172 PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
173 PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
174 PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
175 };
176
177 let log_is_enabled = log::log_enabled!(
178 target: "sqlx::postgres::notice",
179 log_level
180 ) || sqlx_core::private_tracing_dynamic_enabled!(
181 target: "sqlx::postgres::notice",
182 tracing_level
183 );
184 if log_is_enabled {
185 sqlx_core::private_tracing_dynamic_event!(
186 target: "sqlx::postgres::notice",
187 tracing_level,
188 message = notice.message()
189 );
190 }
191
192 continue;
193 }
194
195 _ => {}
196 }
197
198 return Ok(message);
199 }
200 }
201}
202
203impl Deref for PgStream {
204 type Target = BufferedSocket<Box<dyn Socket>>;
205
206 #[inline]
207 fn deref(&self) -> &Self::Target {
208 &self.inner
209 }
210}
211
212impl DerefMut for PgStream {
213 #[inline]
214 fn deref_mut(&mut self) -> &mut Self::Target {
215 &mut self.inner
216 }
217}
218
/// Converts a `server_version` string into its numeric form.
///
/// Leading dot-separated runs of ASCII digits are collected. Scanning stops at the
/// first component that is empty or not a valid `u32`, or right after a digit run
/// that is not followed by `.`, so trailing text such as `devel` or
/// ` (Debian ...)` is ignored. The components then map to a number as follows:
///
/// - `major.minor.rev` -> `(100 * major + minor) * 100 + rev` (`"9.6.1"` -> `90601`)
/// - `major.minor`, `major >= 10` -> `100 * 100 * major + minor` (`"10.1"` -> `100001`)
/// - `major.minor`, `major < 10` -> `(100 * major + minor) * 100` (`"9.6devel"` -> `90600`)
/// - `major` -> `100 * 100 * major` (`"10devel"` -> `100000`)
///
/// Returns `None` when no component is found, when more than three are found, or
/// when the result would not fit in a `u32`. The input comes from the server, so
/// overflow must not panic (debug builds) or silently wrap (release builds).
fn parse_server_version(s: &str) -> Option<u32> {
    let mut parts = Vec::<u32>::with_capacity(3);

    let mut rest = s;
    loop {
        let digits_end = rest
            .find(|ch: char| !ch.is_ascii_digit())
            .unwrap_or(rest.len());

        let Ok(num) = u32::from_str(&rest[..digits_end]) else {
            break;
        };
        parts.push(num);

        // Continue only if this component is directly followed by a '.'.
        match rest[digits_end..].strip_prefix('.') {
            Some(tail) => rest = tail,
            None => break,
        }
    }

    let version_num = match *parts.as_slice() {
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev)?,
        [major, minor] if major >= 10 => major.checked_mul(100 * 100)?.checked_add(minor)?,
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100)?,
        [major] => major.checked_mul(100 * 100)?,
        _ => return None,
    };

    Some(version_num)
}
263
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        // (server_version string, expected numeric form)
        let cases: &[(&str, Option<u32>)] = &[
            ("9.6.1", Some(90601)),
            ("10.1", Some(100001)),
            ("9.6devel", Some(90600)),
            ("10devel", Some(100000)),
            ("13devel87", Some(130000)),
            ("unknown", None),
        ];

        for &(input, expected) in cases {
            assert_eq!(parse_server_version(input), expected, "input: {input:?}");
        }
    }
}