sqlx_etorreborre_postgres/connection/
stream.rs1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut};
3use std::str::FromStr;
4
5use futures_channel::mpsc::UnboundedSender;
6use futures_util::SinkExt;
7use log::Level;
8use sqlx_core::bytes::{Buf, Bytes};
9
10use crate::connection::tls::MaybeUpgradeTls;
11use crate::error::Error;
12use crate::io::{Decode, Encode};
13use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
14use crate::net::{self, BufferedSocket, Socket};
15use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
16
17pub struct PgStream {
27 inner: BufferedSocket<Box<dyn Socket>>,
30
31 pub(crate) notifications: Option<UnboundedSender<Notification>>,
35
36 pub(crate) parameter_statuses: BTreeMap<String, String>,
37
38 pub(crate) server_version_num: Option<u32>,
39}
40
41impl PgStream {
42 pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
43 let socket_future = match options.fetch_socket() {
44 Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
45 None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
46 };
47
48 let socket = socket_future.await?;
49
50 Ok(Self {
51 inner: BufferedSocket::new(socket),
52 notifications: None,
53 parameter_statuses: BTreeMap::default(),
54 server_version_num: None,
55 })
56 }
57
58 pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
59 where
60 T: Encode<'en>,
61 {
62 self.write(message);
63 self.flush().await?;
64 Ok(())
65 }
66
67 pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
69 &mut self,
70 format: MessageFormat,
71 ) -> Result<T, Error> {
72 let message = self.recv().await?;
73
74 if message.format != format {
75 return Err(err_protocol!(
76 "expecting {:?} but received {:?}",
77 format,
78 message.format
79 ));
80 }
81
82 message.decode()
83 }
84
85 pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
86 let mut header: Bytes = self.inner.read(5).await?;
89
90 let format = MessageFormat::try_from_u8(header.get_u8())?;
91 let size = (header.get_u32() - 4) as usize;
92
93 let contents = self.inner.read(size).await?;
94
95 Ok(Message { format, contents })
96 }
97
98 pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
101 loop {
102 let message = self.recv_unchecked().await?;
103
104 match message.format {
105 MessageFormat::ErrorResponse => {
106 return Err(PgDatabaseError(message.decode()?).into());
108 }
109
110 MessageFormat::NotificationResponse => {
111 if let Some(buffer) = &mut self.notifications {
112 let notification: Notification = message.decode()?;
113 let _ = buffer.send(notification).await;
114
115 continue;
116 }
117 }
118
119 MessageFormat::ParameterStatus => {
120 let ParameterStatus { name, value } = message.decode()?;
124 match name.as_str() {
127 "server_version" => {
128 self.server_version_num = parse_server_version(&value);
129 }
130 _ => {
131 self.parameter_statuses.insert(name, value);
132 }
133 }
134
135 continue;
136 }
137
138 MessageFormat::NoticeResponse => {
139 let notice: Notice = message.decode()?;
143
144 let (log_level, tracing_level) = match notice.severity() {
145 PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
146 (Level::Error, tracing::Level::ERROR)
147 }
148 PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
149 PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
150 PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
151 PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
152 };
153
154 let log_is_enabled = log::log_enabled!(
155 target: "sqlx::postgres::notice",
156 log_level
157 ) || sqlx_core::private_tracing_dynamic_enabled!(
158 target: "sqlx::postgres::notice",
159 tracing_level
160 );
161 if log_is_enabled {
162 let message = format!("{}", notice.message());
163 sqlx_core::private_tracing_dynamic_event!(
164 target: "sqlx::postgres::notice",
165 tracing_level,
166 message
167 );
168 }
169
170 continue;
171 }
172
173 _ => {}
174 }
175
176 return Ok(message);
177 }
178 }
179}
180
181impl Deref for PgStream {
182 type Target = BufferedSocket<Box<dyn Socket>>;
183
184 #[inline]
185 fn deref(&self) -> &Self::Target {
186 &self.inner
187 }
188}
189
190impl DerefMut for PgStream {
191 #[inline]
192 fn deref_mut(&mut self) -> &mut Self::Target {
193 &mut self.inner
194 }
195}
196
/// Converts a Postgres `server_version` string into the numeric form used by
/// `server_version_num` (e.g. `"9.6.1"` -> `90601`, `"10.1"` -> `100001`).
///
/// Only the leading run of ASCII digits and dots is considered, so suffixes such
/// as `devel` or ` (Debian ...)` are ignored. Dot-separated components are read
/// until the first one that is empty or does not fit in a `u32`.
///
/// Returns `None` when no component could be read, when more than three were
/// read, or when the combined number would overflow a `u32`.
fn parse_server_version(s: &str) -> Option<u32> {
    // End of the `<digits>[.<digits>...]` prefix; the rest is a free-form suffix.
    let end = s
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(s.len());

    let parts: Vec<u32> = s[..end]
        .split('.')
        .map_while(|part| u32::from_str(part).ok())
        .collect();

    // The string is reported by the server, so combine the parts with checked
    // arithmetic rather than panicking (debug) or wrapping (release) on absurd input.
    let version_num = match parts[..] {
        [major, minor, rev] => major
            .checked_mul(100)?
            .checked_add(minor)?
            .checked_mul(100)?
            .checked_add(rev)?,
        // Since Postgres 10 the version is `major.minor` with no revision.
        [major, minor] if major >= 10 => major.checked_mul(10_000)?.checked_add(minor)?,
        [major, minor] => major.checked_mul(100)?.checked_add(minor)?.checked_mul(100)?,
        [major] => major.checked_mul(10_000)?,
        _ => return None,
    };

    Some(version_num)
}
241
#[cfg(test)]
mod tests {
    use super::parse_server_version;

    #[test]
    fn test_parse_server_version_num() {
        // Release versions: pre-10 use three components, 10+ use two.
        assert_eq!(parse_server_version("9.6.1"), Some(90601));
        assert_eq!(parse_server_version("10.1"), Some(100001));

        // Development builds: the non-numeric suffix is ignored.
        assert_eq!(parse_server_version("9.6devel"), Some(90600));
        assert_eq!(parse_server_version("10devel"), Some(100000));
        assert_eq!(parse_server_version("13devel87"), Some(130000));

        // Distribution builds append a parenthesised description.
        assert_eq!(
            parse_server_version("14.5 (Debian 14.5-1.pgdg110+1)"),
            Some(140005)
        );

        // Nothing numeric to parse.
        assert_eq!(parse_server_version("unknown"), None);
        assert_eq!(parse_server_version(""), None);

        // More components than the version format allows.
        assert_eq!(parse_server_version("9.6.1.2"), None);
    }
}