1use core::net::SocketAddr;
4
5use fallible_iterator::FallibleIterator;
6use postgres_protocol::{
7 authentication::{self, sasl},
8 message::{backend, frontend},
9};
10use xitca_io::{bytes::BytesMut, io::AsyncIo};
11
12use super::{
13 config::{Config, SslMode, SslNegotiation},
14 driver::generic::GenericDriver,
15 error::{AuthenticationError, Error},
16};
17
/// Which kind of server session the client is willing to accept after connecting.
///
/// Anything other than [`TargetSessionAttrs::Any`] makes session preparation query
/// `SHOW transaction_read_only` and reject a server whose answer does not fit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Accept any server; no read-only probe is issued.
    Any,
    /// Reject a server that reports `transaction_read_only = on`.
    ReadWrite,
    /// Reject a server that reports `transaction_read_only = off`.
    ReadOnly,
}
29
30#[derive(Clone)]
32pub struct Session {
33 pub(crate) id: i32,
34 pub(crate) key: i32,
35 pub(crate) info: ConnectInfo,
36}
37
38#[derive(Clone, Default)]
39pub(crate) struct ConnectInfo {
40 pub(crate) addr: Addr,
41 pub(crate) ssl_mode: SslMode,
42 pub(crate) ssl_negotiation: SslNegotiation,
43}
44
45impl ConnectInfo {
46 pub(crate) fn new(addr: Addr, ssl_mode: SslMode, ssl_negotiation: SslNegotiation) -> Self {
47 Self {
48 addr,
49 ssl_mode,
50 ssl_negotiation,
51 }
52 }
53}
54
/// Transport-level address of a PostgreSQL server.
///
/// Each connected variant carries a `Box<str>` alongside the resolved address.
/// NOTE(review): the string is presumably the host name as configured — confirm at the call sites.
#[derive(Clone, Default)]
pub(crate) enum Addr {
    /// TCP socket address.
    Tcp(Box<str>, SocketAddr),
    /// Unix domain socket path (unix targets only).
    #[cfg(unix)]
    Unix(Box<str>, std::path::PathBuf),
    /// QUIC endpoint (only with the `quic` feature).
    #[cfg(feature = "quic")]
    Quic(Box<str>, SocketAddr),
    /// No address recorded; this is the `Default`.
    #[default]
    None,
}
66
67impl Session {
68 fn new(info: ConnectInfo) -> Self {
69 Self { id: 0, key: 0, info }
70 }
71}
72
73impl Session {
74 #[allow(clippy::needless_pass_by_ref_mut)] #[cold]
76 #[inline(never)]
77 pub(super) async fn prepare_session<Io>(
78 info: ConnectInfo,
79 drv: &mut GenericDriver<Io>,
80 cfg: &Config,
81 ) -> Result<Self, Error>
82 where
83 Io: AsyncIo + Send,
84 {
85 let mut buf = BytesMut::new();
86
87 auth(drv, cfg, &mut buf).await?;
88
89 let mut session = Session::new(info);
90
91 loop {
92 match drv.recv().await? {
93 backend::Message::ReadyForQuery(_) => break,
94 backend::Message::BackendKeyData(body) => {
95 session.id = body.process_id();
96 session.key = body.secret_key();
97 }
98 backend::Message::ParameterStatus(body) => {
99 let _name = body.name()?;
101 let _value = body.value()?;
102 }
103 backend::Message::ErrorResponse(body) => return Err(Error::db(body.fields())),
104 backend::Message::NoticeResponse(_) => {
105 }
107 _ => return Err(Error::unexpected()),
108 }
109 }
110
111 if !matches!(cfg.get_target_session_attrs(), TargetSessionAttrs::Any) {
112 frontend::query("SHOW transaction_read_only", &mut buf)?;
113 let msg = buf.split();
114 drv.send(msg).await?;
115 loop {
117 match drv.recv().await? {
118 backend::Message::DataRow(body) => {
119 let range = body.ranges().next()?.flatten().ok_or(Error::todo())?;
120 let slice = &body.buffer()[range.start..range.end];
121 match (slice, cfg.get_target_session_attrs()) {
122 (b"on", TargetSessionAttrs::ReadWrite) => return Err(Error::todo()),
123 (b"off", TargetSessionAttrs::ReadOnly) => return Err(Error::todo()),
124 _ => {}
125 }
126 }
127 backend::Message::RowDescription(_) | backend::Message::CommandComplete(_) => {}
128 backend::Message::EmptyQueryResponse | backend::Message::ReadyForQuery(_) => break,
129 _ => return Err(Error::unexpected()),
130 }
131 }
132 }
133
134 Ok(session)
135 }
136}
137
138#[cold]
139#[inline(never)]
140async fn auth<Io>(drv: &mut GenericDriver<Io>, cfg: &Config, buf: &mut BytesMut) -> Result<(), Error>
141where
142 Io: AsyncIo + Send,
143{
144 let mut params = vec![("client_encoding", "UTF8")];
145 if let Some(user) = &cfg.user {
146 params.push(("user", &**user));
147 }
148 if let Some(dbname) = &cfg.dbname {
149 params.push(("database", &**dbname));
150 }
151 if let Some(options) = &cfg.options {
152 params.push(("options", &**options));
153 }
154 if let Some(application_name) = &cfg.application_name {
155 params.push(("application_name", &**application_name));
156 }
157
158 frontend::startup_message(params, buf)?;
159 let msg = buf.split();
160 drv.send(msg).await?;
161
162 loop {
163 match drv.recv().await? {
164 backend::Message::AuthenticationOk => return Ok(()),
165 backend::Message::AuthenticationCleartextPassword => {
166 let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
167 send_pass(drv, pass, buf).await?;
168 }
169 backend::Message::AuthenticationMd5Password(body) => {
170 let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
171 let user = cfg.get_user().ok_or(AuthenticationError::MissingUserName)?.as_bytes();
172 let pass = authentication::md5_hash(user, pass, body.salt());
173 send_pass(drv, pass, buf).await?;
174 }
175 backend::Message::AuthenticationSasl(body) => {
176 let pass = cfg.get_password().ok_or(AuthenticationError::MissingPassWord)?;
177
178 let mut is_scram = false;
179 let mut is_scram_plus = false;
180 let mut mechanisms = body.mechanisms();
181
182 while let Some(mechanism) = mechanisms.next()? {
183 match mechanism {
184 sasl::SCRAM_SHA_256 => is_scram = true,
185 sasl::SCRAM_SHA_256_PLUS => is_scram_plus = true,
186 _ => {}
187 }
188 }
189
190 let (channel_binding, mechanism) = match (is_scram_plus, is_scram) {
191 (true, is_scram) => {
192 let buf = cfg.get_tls_server_end_point();
193 match (buf, is_scram) {
194 (Some(buf), _) => (
195 sasl::ChannelBinding::tls_server_end_point(buf.to_owned()),
196 sasl::SCRAM_SHA_256_PLUS,
197 ),
198 (None, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
199 _ => return Err(Error::todo()),
202 }
203 }
204 (false, true) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
205 (false, false) => return Err(Error::todo()),
207 };
208
209 let mut scram = sasl::ScramSha256::new(pass, channel_binding);
210
211 frontend::sasl_initial_response(mechanism, scram.message(), buf)?;
212 let msg = buf.split();
213 drv.send(msg).await?;
214
215 match drv.recv().await? {
216 backend::Message::AuthenticationSaslContinue(body) => {
217 scram.update(body.data())?;
218 frontend::sasl_response(scram.message(), buf)?;
219 let msg = buf.split();
220 drv.send(msg).await?;
221 }
222 _ => return Err(Error::todo()),
223 }
224
225 match drv.recv().await? {
226 backend::Message::AuthenticationSaslFinal(body) => scram.finish(body.data())?,
227 _ => return Err(Error::todo()),
228 }
229 }
230 backend::Message::ErrorResponse(_) => return Err(Error::from(AuthenticationError::WrongPassWord)),
231 _ => {}
232 }
233 }
234}
235
236async fn send_pass<Io>(drv: &mut GenericDriver<Io>, pass: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), Error>
237where
238 Io: AsyncIo + Send,
239{
240 frontend::password_message(pass.as_ref(), buf)?;
241 let msg = buf.split();
242 drv.send(msg).await
243}