1use crate::Result;
2use crate::message::Message;
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5use yykv_types::DsError;
6
7#[derive(Debug, Clone)]
8pub enum BackendMessage {
9 AuthenticationOk,
10 AuthenticationCleartextPassword,
11 AuthenticationMD5Password { salt: [u8; 4] },
12 ParameterStatus { name: String, value: String },
13 BackendKeyData { process_id: i32, secret_key: i32 },
14 ReadyForQuery { status: u8 },
15 RowDescription { fields: Vec<FieldDescription> },
16 DataRow { values: Vec<Option<Bytes>> },
17 CommandComplete { tag: String },
18 ErrorResponse { fields: Vec<(u8, String)> },
19 NoticeResponse { fields: Vec<(u8, String)> },
20 ParseComplete,
21 BindComplete,
22 NoData,
23 ParameterDescription { ids: Vec<u32> },
24 CloseComplete,
25}
26
/// One column entry of a [`BackendMessage::RowDescription`].
///
/// The fields are listed in the order in which they are written to and read
/// from the wire: a NUL-terminated name followed by six fixed-width
/// big-endian integers.
#[derive(Debug, Clone)]
pub struct FieldDescription {
    /// Column name (NUL-terminated on the wire).
    pub name: String,
    /// 32-bit table OID.
    pub table_oid: i32,
    /// 16-bit column id.
    pub column_id: i16,
    /// 32-bit type OID.
    pub type_oid: i32,
    /// 16-bit type size.
    pub type_size: i16,
    /// 32-bit type modifier.
    pub type_modifier: i32,
    /// 16-bit format code.
    pub format_code: i16,
}
37
/// Stateless client-side codec: decodes [`BackendMessage`] frames and
/// encodes frontend [`Message`] values.
#[derive(Debug, Clone, Default)]
pub struct PgCodec;

impl PgCodec {
    /// Creates a new codec; the type carries no state.
    pub fn new() -> Self {
        PgCodec
    }
}
46
/// Stateless server-side codec: decodes frontend [`Message`] frames and
/// encodes [`BackendMessage`] values.
pub struct PgServerCodec;

impl PgServerCodec {
    /// Creates a new codec; the type carries no state.
    pub fn new() -> Self {
        PgServerCodec
    }
}

impl Default for PgServerCodec {
    /// Equivalent to [`PgServerCodec::new`].
    fn default() -> Self {
        PgServerCodec::new()
    }
}
60
61impl Decoder for PgServerCodec {
62 type Item = Message;
63 type Error = DsError;
64
65 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
66 if src.is_empty() {
67 return Ok(None);
68 }
69
70 if src.len() >= 4 {
72 let len = i32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
73 if len == 8 && src.len() >= 8 {
74 let code = i32::from_be_bytes([src[4], src[5], src[6], src[7]]);
75 if code == 80877103 {
76 src.advance(8);
77 return Ok(Some(Message::SslRequest));
78 }
79 }
80
81 if len > 8 && src.len() >= len {
83 let protocol = i32::from_be_bytes([src[4], src[5], src[6], src[7]]);
84 if protocol == 196608 {
85 let mut body = src.split_to(len);
87 body.advance(8);
88 let mut params = Vec::new();
89 while body[0] != 0 {
90 let name = PgCodec::read_string(&mut body)?;
91 let value = PgCodec::read_string(&mut body)?;
92 params.push((name, value));
93 }
94 return Ok(Some(Message::Startup { params }));
95 }
96 }
97 }
98
99 if src.len() < 5 {
100 return Ok(None);
101 }
102
103 let tag = src[0];
104 let len = i32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
105
106 if src.len() < len + 1 {
107 return Ok(None);
108 }
109
110 let mut body = src.split_to(len + 1);
111 body.advance(5);
112
113 match tag {
114 b'p' => {
115 let pass = PgCodec::read_string(&mut body)?;
116 Ok(Some(Message::Password(pass)))
117 }
118 b'Q' => {
119 let query = PgCodec::read_string(&mut body)?;
120 Ok(Some(Message::Query(query)))
121 }
122 b'P' => {
123 let name = PgCodec::read_string(&mut body)?;
124 let query = PgCodec::read_string(&mut body)?;
125 let num_params = body.get_i16();
126 let mut param_types = Vec::with_capacity(num_params as usize);
127 for _ in 0..num_params {
128 param_types.push(body.get_u32());
129 }
130 Ok(Some(Message::Parse {
131 name,
132 query,
133 param_types,
134 }))
135 }
136 b'B' => {
137 let portal = PgCodec::read_string(&mut body)?;
138 let statement = PgCodec::read_string(&mut body)?;
139
140 let num_format_codes = body.get_i16();
141 let mut format_codes = Vec::with_capacity(num_format_codes as usize);
142 for _ in 0..num_format_codes {
143 format_codes.push(body.get_i16());
144 }
145
146 let num_params = body.get_i16();
147 let mut params = Vec::with_capacity(num_params as usize);
148 for i in 0..num_params {
149 let len = body.get_i32();
150 if len == -1 {
151 params.push(yykv_types::DsValue::Null);
152 } else {
153 let mut val_bytes = vec![0u8; len as usize];
154 body.copy_to_slice(&mut val_bytes);
155
156 let format = if format_codes.len() == 1 {
157 format_codes[0]
158 } else if format_codes.len() > i as usize {
159 format_codes[i as usize]
160 } else {
161 0 };
163
164 if format == 0 {
165 let s = String::from_utf8_lossy(&val_bytes).to_string();
167 params.push(yykv_types::DsValue::Text(s));
168 } else {
169 params.push(yykv_types::DsValue::Bytes(val_bytes.into()));
171 }
172 }
173 }
174 Ok(Some(Message::Bind {
175 portal,
176 statement,
177 params,
178 }))
179 }
180 b'D' => {
181 let target_type = body.get_u8();
182 let name = PgCodec::read_string(&mut body)?;
183 Ok(Some(Message::Describe { target_type, name }))
184 }
185 b'C' => {
186 let target_type = body.get_u8();
187 let name = PgCodec::read_string(&mut body)?;
188 Ok(Some(Message::Close { target_type, name }))
189 }
190 b'E' => {
191 let portal = PgCodec::read_string(&mut body)?;
192 let max_rows = body.get_i32();
193 Ok(Some(Message::Execute { portal, max_rows }))
194 }
195 b'S' => Ok(Some(Message::Sync)),
196 b'H' => Ok(Some(Message::Flush)),
197 b'X' => Ok(Some(Message::Terminate)),
198 _ => {
199 Ok(None)
201 }
202 }
203 }
204}
205
206impl Encoder<BackendMessage> for PgServerCodec {
207 type Error = DsError;
208
209 fn encode(&mut self, item: BackendMessage, dst: &mut BytesMut) -> Result<()> {
210 item.encode(dst);
211 Ok(())
212 }
213}
214
215impl BackendMessage {
216 pub fn encode(&self, dst: &mut BytesMut) {
217 match self {
218 BackendMessage::AuthenticationOk => {
219 dst.put_u8(b'R');
220 dst.put_i32(8);
221 dst.put_i32(0);
222 }
223 BackendMessage::ReadyForQuery { status } => {
224 dst.put_u8(b'Z');
225 dst.put_i32(5);
226 dst.put_u8(*status);
227 }
228 BackendMessage::CommandComplete { tag } => {
229 dst.put_u8(b'C');
230 let len = tag.len() + 1 + 4;
231 dst.put_i32(len as i32);
232 dst.put_slice(tag.as_bytes());
233 dst.put_u8(0);
234 }
235 BackendMessage::ParameterStatus { name, value } => {
236 dst.put_u8(b'S');
237 let len = name.len() + value.len() + 2 + 4;
238 dst.put_i32(len as i32);
239 dst.put_slice(name.as_bytes());
240 dst.put_u8(0);
241 dst.put_slice(value.as_bytes());
242 dst.put_u8(0);
243 }
244 BackendMessage::RowDescription { fields } => {
245 dst.put_u8(b'T');
246 let mut payload = BytesMut::new();
247 payload.put_i16(fields.len() as i16);
248 for field in fields {
249 payload.put_slice(field.name.as_bytes());
250 payload.put_u8(0);
251 payload.put_i32(field.table_oid);
252 payload.put_i16(field.column_id);
253 payload.put_i32(field.type_oid);
254 payload.put_i16(field.type_size);
255 payload.put_i32(field.type_modifier);
256 payload.put_i16(field.format_code);
257 }
258 dst.put_i32(payload.len() as i32 + 4);
259 dst.extend_from_slice(&payload);
260 }
261 BackendMessage::DataRow { values } => {
262 dst.put_u8(b'D');
263 let mut payload = BytesMut::new();
264 payload.put_i16(values.len() as i16);
265 for val in values {
266 match val {
267 Some(v) => {
268 payload.put_i32(v.len() as i32);
269 payload.put_slice(v);
270 }
271 None => {
272 payload.put_i32(-1);
273 }
274 }
275 }
276 dst.put_i32(payload.len() as i32 + 4);
277 dst.extend_from_slice(&payload);
278 }
279 BackendMessage::ErrorResponse { fields } => {
280 dst.put_u8(b'E');
281 let mut payload = BytesMut::new();
282 for (tag, msg) in fields {
283 payload.put_u8(*tag);
284 payload.put_slice(msg.as_bytes());
285 payload.put_u8(0);
286 }
287 payload.put_u8(0);
288 dst.put_i32(payload.len() as i32 + 4);
289 dst.extend_from_slice(&payload);
290 }
291 BackendMessage::ParameterDescription { ids } => {
292 dst.put_u8(b't');
293 let mut payload = BytesMut::new();
294 payload.put_i16(ids.len() as i16);
295 for &id in ids {
296 payload.put_u32(id);
297 }
298 dst.put_i32(payload.len() as i32 + 4);
299 dst.extend_from_slice(&payload);
300 }
301 BackendMessage::CloseComplete => {
302 dst.put_u8(b'3');
303 dst.put_i32(4);
304 }
305 BackendMessage::ParseComplete => {
306 dst.put_u8(b'1');
307 dst.put_i32(4);
308 }
309 BackendMessage::BindComplete => {
310 dst.put_u8(b'2');
311 dst.put_i32(4);
312 }
313 BackendMessage::NoData => {
314 dst.put_u8(b'n');
315 dst.put_i32(4);
316 }
317 _ => {
318 }
320 }
321 }
322}
323
324impl Decoder for PgCodec {
325 type Item = BackendMessage;
326 type Error = DsError;
327
328 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
329 if src.len() < 5 {
330 return Ok(None);
331 }
332
333 let tag = src[0];
334 let len = i32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
335
336 if src.len() < len + 1 {
337 return Ok(None);
338 }
339
340 let mut body = src.split_to(len + 1);
341 body.advance(5); let msg = match tag {
344 b'R' => Self::decode_authentication(&mut body)?,
345 b'S' => {
346 let name = Self::read_string(&mut body)?;
347 let value = Self::read_string(&mut body)?;
348 BackendMessage::ParameterStatus { name, value }
349 }
350 b'K' => {
351 let process_id = body.get_i32();
352 let secret_key = body.get_i32();
353 BackendMessage::BackendKeyData {
354 process_id,
355 secret_key,
356 }
357 }
358 b'Z' => {
359 let status = body.get_u8();
360 BackendMessage::ReadyForQuery { status }
361 }
362 b'T' => Self::decode_row_description(&mut body)?,
363 b'D' => Self::decode_data_row(&mut body)?,
364 b'C' => {
365 let tag = Self::read_string(&mut body)?;
366 BackendMessage::CommandComplete { tag }
367 }
368 b'E' => BackendMessage::ErrorResponse {
369 fields: Self::decode_error_notice(&mut body)?,
370 },
371 b'N' => BackendMessage::NoticeResponse {
372 fields: Self::decode_error_notice(&mut body)?,
373 },
374 b'1' => BackendMessage::ParseComplete,
375 b'2' => BackendMessage::BindComplete,
376 b'n' => BackendMessage::NoData,
377 b't' => {
378 let count = body.get_i16();
379 let mut ids = Vec::with_capacity(count as usize);
380 for _ in 0..count {
381 ids.push(body.get_u32());
382 }
383 BackendMessage::ParameterDescription { ids }
384 }
385 b'3' => BackendMessage::CloseComplete,
386 _ => {
387 return Err(DsError::protocol(format!(
388 "Unknown backend tag: {}",
389 tag as char
390 )));
391 }
392 };
393
394 Ok(Some(msg))
395 }
396}
397
398impl Encoder<Message> for PgCodec {
399 type Error = DsError;
400
401 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<()> {
402 item.encode(dst);
403 Ok(())
404 }
405}
406
407impl PgCodec {
408 fn decode_authentication(body: &mut BytesMut) -> Result<BackendMessage> {
409 let auth_type = body.get_i32();
410 match auth_type {
411 0 => Ok(BackendMessage::AuthenticationOk),
412 3 => Ok(BackendMessage::AuthenticationCleartextPassword),
413 5 => {
414 let mut salt = [0u8; 4];
415 body.copy_to_slice(&mut salt);
416 Ok(BackendMessage::AuthenticationMD5Password { salt })
417 }
418 _ => Err(DsError::protocol(format!(
419 "Unsupported authentication type: {}",
420 auth_type
421 ))),
422 }
423 }
424
425 fn decode_row_description(body: &mut BytesMut) -> Result<BackendMessage> {
426 let count = body.get_i16();
427 let mut fields = Vec::with_capacity(count as usize);
428 for _ in 0..count {
429 fields.push(FieldDescription {
430 name: Self::read_string(body)?,
431 table_oid: body.get_i32(),
432 column_id: body.get_i16(),
433 type_oid: body.get_i32(),
434 type_size: body.get_i16(),
435 type_modifier: body.get_i32(),
436 format_code: body.get_i16(),
437 });
438 }
439 Ok(BackendMessage::RowDescription { fields })
440 }
441
442 fn decode_data_row(body: &mut BytesMut) -> Result<BackendMessage> {
443 let count = body.get_i16();
444 let mut values = Vec::with_capacity(count as usize);
445 for _ in 0..count {
446 let len = body.get_i32();
447 if len == -1 {
448 values.push(None);
449 } else {
450 values.push(Some(body.split_to(len as usize).freeze()));
451 }
452 }
453 Ok(BackendMessage::DataRow { values })
454 }
455
456 fn decode_error_notice(body: &mut BytesMut) -> Result<Vec<(u8, String)>> {
457 let mut fields = Vec::new();
458 loop {
459 let tag = body.get_u8();
460 if tag == 0 {
461 break;
462 }
463 fields.push((tag, Self::read_string(body)?));
464 }
465 Ok(fields)
466 }
467
468 fn read_string(body: &mut BytesMut) -> Result<String> {
469 let pos = body
470 .iter()
471 .position(|&b| b == 0)
472 .ok_or_else(|| DsError::protocol("String not null-terminated"))?;
473 let s = String::from_utf8_lossy(&body[..pos]).into_owned();
474 body.advance(pos + 1);
475 Ok(s)
476 }
477}