1#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8 BackendMessage, ErrorFields, FieldDescription, TransactionStatus, auth_type, backend_type,
9};
10use std::error::Error as StdError;
11use std::fmt;
12
/// Errors produced while framing or decoding backend messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolError {
    /// A frame handed to `MessageReader::parse_message` is shorter than its
    /// 5-byte header or than the length the header declares.
    Incomplete,
    /// The declared length is below 4 (the length field always counts itself).
    InvalidLength { length: i32 },
    /// The frame's total size (type byte included) exceeds the reader's limit.
    MessageTooLarge { length: usize, max: usize },
    /// The type byte does not match any supported backend message.
    UnknownMessageType(u8),
    /// A NUL-terminated string in the payload is not valid UTF-8.
    Utf8(std::string::FromUtf8Error),
    /// The payload ended before a field could be read completely.
    UnexpectedEof,
    /// A field was read but its value is not acceptable (e.g. a negative count).
    InvalidField(&'static str),
}

impl fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Incomplete => f.write_str("incomplete message"),
            Self::InvalidLength { length } => write!(f, "invalid message length: {length}"),
            Self::MessageTooLarge { length, max } => {
                write!(f, "message too large: {length} > {max}")
            }
            Self::UnknownMessageType(ty) => write!(f, "unknown message type: 0x{ty:02x}"),
            Self::Utf8(err) => write!(f, "utf-8 error: {err}"),
            Self::UnexpectedEof => f.write_str("unexpected end of buffer"),
            Self::InvalidField(msg) => write!(f, "invalid field: {msg}"),
        }
    }
}

impl StdError for ProtocolError {
    /// Exposes the underlying UTF-8 decoding error so callers walking the
    /// error chain can see the original cause.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Utf8(err) => Some(err),
            _ => None,
        }
    }
}

impl From<std::string::FromUtf8Error> for ProtocolError {
    /// Wraps a UTF-8 decoding failure so `?` can propagate it.
    fn from(err: std::string::FromUtf8Error) -> Self {
        Self::Utf8(err)
    }
}
59
60#[derive(Debug, Clone)]
62pub struct MessageReader {
63 buf: Vec<u8>,
64 max_message_size: usize,
65}
66
67impl Default for MessageReader {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl MessageReader {
74 pub fn new() -> Self {
76 Self::with_max_size(8 * 1024 * 1024)
77 }
78
79 pub fn with_max_size(max_message_size: usize) -> Self {
81 Self {
82 buf: Vec::new(),
83 max_message_size,
84 }
85 }
86
87 pub fn buffered_len(&self) -> usize {
89 self.buf.len()
90 }
91
92 pub fn push(&mut self, data: &[u8]) {
99 self.buf.extend_from_slice(data);
100 }
101
102 pub fn feed(&mut self, data: &[u8]) -> Result<Vec<BackendMessage>, ProtocolError> {
104 self.buf.extend_from_slice(data);
105
106 let mut messages = Vec::new();
107 while let Some(msg) = self.next_message()? {
108 messages.push(msg);
109 }
110 Ok(messages)
111 }
112
113 pub fn next_message(&mut self) -> Result<Option<BackendMessage>, ProtocolError> {
115 if self.buf.len() < 5 {
116 return Ok(None);
117 }
118
119 let length = i32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]]);
120 if length < 4 {
121 return Err(ProtocolError::InvalidLength { length });
122 }
123
124 let total_len = length as usize + 1;
125 if total_len > self.max_message_size {
126 return Err(ProtocolError::MessageTooLarge {
127 length: total_len,
128 max: self.max_message_size,
129 });
130 }
131
132 if self.buf.len() < total_len {
133 return Ok(None);
134 }
135
136 let frame = self.buf[..total_len].to_vec();
137 self.buf.drain(..total_len);
138 Ok(Some(Self::parse_message(&frame)?))
139 }
140
141 pub fn parse_message(frame: &[u8]) -> Result<BackendMessage, ProtocolError> {
143 if frame.len() < 5 {
144 return Err(ProtocolError::Incomplete);
145 }
146
147 let ty = frame[0];
148 let length = i32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]);
149 if length < 4 {
150 return Err(ProtocolError::InvalidLength { length });
151 }
152
153 let total_len = length as usize + 1;
154 if frame.len() < total_len {
155 return Err(ProtocolError::Incomplete);
156 }
157
158 let payload = &frame[5..total_len];
159 let mut cur = Cursor::new(payload);
160
161 match ty {
162 backend_type::AUTHENTICATION => parse_authentication(&mut cur),
163 backend_type::BACKEND_KEY_DATA => parse_backend_key_data(&mut cur),
164 backend_type::PARAMETER_STATUS => parse_parameter_status(&mut cur),
165 backend_type::READY_FOR_QUERY => parse_ready_for_query(&mut cur),
166 backend_type::ROW_DESCRIPTION => parse_row_description(&mut cur),
167 backend_type::DATA_ROW => parse_data_row(&mut cur),
168 backend_type::COMMAND_COMPLETE => parse_command_complete(&mut cur),
169 backend_type::EMPTY_QUERY => Ok(BackendMessage::EmptyQueryResponse),
170 backend_type::PARSE_COMPLETE => Ok(BackendMessage::ParseComplete),
171 backend_type::BIND_COMPLETE => Ok(BackendMessage::BindComplete),
172 backend_type::CLOSE_COMPLETE => Ok(BackendMessage::CloseComplete),
173 backend_type::PARAMETER_DESCRIPTION => parse_parameter_description(&mut cur),
174 backend_type::NO_DATA => Ok(BackendMessage::NoData),
175 backend_type::PORTAL_SUSPENDED => Ok(BackendMessage::PortalSuspended),
176 backend_type::ERROR_RESPONSE => parse_error_response(&mut cur, true),
177 backend_type::NOTICE_RESPONSE => parse_error_response(&mut cur, false),
178 backend_type::COPY_IN_RESPONSE => parse_copy_in_response(&mut cur),
179 backend_type::COPY_OUT_RESPONSE => parse_copy_out_response(&mut cur),
180 backend_type::COPY_BOTH_RESPONSE => parse_copy_both_response(&mut cur),
181 backend_type::COPY_DATA => Ok(BackendMessage::CopyData(cur.take_remaining())),
182 backend_type::COPY_DONE => Ok(BackendMessage::CopyDone),
183 backend_type::NOTIFICATION_RESPONSE => parse_notification_response(&mut cur),
184 backend_type::FUNCTION_CALL_RESPONSE => parse_function_call_response(&mut cur),
185 backend_type::NEGOTIATE_PROTOCOL_VERSION => parse_negotiate_protocol_version(&mut cur),
186 _ => Err(ProtocolError::UnknownMessageType(ty)),
187 }
188 }
189}
190
191fn parse_authentication(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
192 let auth_type = cur.read_i32()?;
193 match auth_type {
194 auth_type::OK => Ok(BackendMessage::AuthenticationOk),
195 auth_type::CLEARTEXT_PASSWORD => Ok(BackendMessage::AuthenticationCleartextPassword),
196 auth_type::MD5_PASSWORD => {
197 let salt = cur.read_bytes(4)?;
198 let mut buf = [0_u8; 4];
199 buf.copy_from_slice(salt);
200 Ok(BackendMessage::AuthenticationMD5Password(buf))
201 }
202 auth_type::SASL => {
203 let mut mechanisms = Vec::new();
204 loop {
205 let mech = cur.read_cstring()?;
206 if mech.is_empty() {
207 break;
208 }
209 mechanisms.push(mech);
210 }
211 Ok(BackendMessage::AuthenticationSASL(mechanisms))
212 }
213 auth_type::SASL_CONTINUE => Ok(BackendMessage::AuthenticationSASLContinue(
214 cur.take_remaining(),
215 )),
216 auth_type::SASL_FINAL => Ok(BackendMessage::AuthenticationSASLFinal(
217 cur.take_remaining(),
218 )),
219 _ => Err(ProtocolError::InvalidField("unknown auth type")),
220 }
221}
222
223fn parse_backend_key_data(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
224 let process_id = cur.read_i32()?;
225 let secret_key = cur.read_i32()?;
226 Ok(BackendMessage::BackendKeyData {
227 process_id,
228 secret_key,
229 })
230}
231
232fn parse_parameter_status(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
233 let name = cur.read_cstring()?;
234 let value = cur.read_cstring()?;
235 Ok(BackendMessage::ParameterStatus { name, value })
236}
237
238fn parse_ready_for_query(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
239 let status = cur.read_u8()?;
240 let status = TransactionStatus::from_byte(status)
241 .ok_or(ProtocolError::InvalidField("invalid transaction status"))?;
242 Ok(BackendMessage::ReadyForQuery(status))
243}
244
245fn parse_row_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
246 let count = cur.read_i16()?;
247 if count < 0 {
248 return Err(ProtocolError::InvalidField("negative field count"));
249 }
250 let mut fields = Vec::with_capacity(count as usize);
251 for _ in 0..count {
252 let name = cur.read_cstring()?;
253 let table_oid = cur.read_u32()?;
254 let column_id = cur.read_i16()?;
255 let type_oid = cur.read_u32()?;
256 let type_size = cur.read_i16()?;
257 let type_modifier = cur.read_i32()?;
258 let format = cur.read_i16()?;
259 fields.push(FieldDescription {
260 name,
261 table_oid,
262 column_id,
263 type_oid,
264 type_size,
265 type_modifier,
266 format,
267 });
268 }
269 Ok(BackendMessage::RowDescription(fields))
270}
271
272fn parse_data_row(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
273 let count = cur.read_i16()?;
274 if count < 0 {
275 return Err(ProtocolError::InvalidField("negative column count"));
276 }
277 let mut values = Vec::with_capacity(count as usize);
278 for _ in 0..count {
279 let len = cur.read_i32()?;
280 if len == -1 {
281 values.push(None);
282 continue;
283 }
284 if len < 0 {
285 return Err(ProtocolError::InvalidField("negative data length"));
286 }
287 let bytes = cur.read_bytes(len as usize)?.to_vec();
288 values.push(Some(bytes));
289 }
290 Ok(BackendMessage::DataRow(values))
291}
292
293fn parse_command_complete(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
294 let tag = cur.read_cstring()?;
295 Ok(BackendMessage::CommandComplete(tag))
296}
297
298fn parse_parameter_description(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
299 let count = cur.read_i16()?;
300 if count < 0 {
301 return Err(ProtocolError::InvalidField("negative parameter count"));
302 }
303 let mut oids = Vec::with_capacity(count as usize);
304 for _ in 0..count {
305 oids.push(cur.read_u32()?);
306 }
307 Ok(BackendMessage::ParameterDescription(oids))
308}
309
310fn parse_copy_in_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
311 let format = cur.read_i8()?;
312 let column_formats = read_column_formats(cur)?;
313 Ok(BackendMessage::CopyInResponse {
314 format,
315 column_formats,
316 })
317}
318
319fn parse_copy_out_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
320 let format = cur.read_i8()?;
321 let column_formats = read_column_formats(cur)?;
322 Ok(BackendMessage::CopyOutResponse {
323 format,
324 column_formats,
325 })
326}
327
328fn parse_copy_both_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
329 let format = cur.read_i8()?;
330 let column_formats = read_column_formats(cur)?;
331 Ok(BackendMessage::CopyBothResponse {
332 format,
333 column_formats,
334 })
335}
336
337fn read_column_formats(cur: &mut Cursor<'_>) -> Result<Vec<i16>, ProtocolError> {
338 let count = cur.read_i16()?;
339 if count < 0 {
340 return Err(ProtocolError::InvalidField("negative format count"));
341 }
342 let mut formats = Vec::with_capacity(count as usize);
343 for _ in 0..count {
344 formats.push(cur.read_i16()?);
345 }
346 Ok(formats)
347}
348
349fn parse_notification_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
350 let process_id = cur.read_i32()?;
351 let channel = cur.read_cstring()?;
352 let payload = cur.read_cstring()?;
353 Ok(BackendMessage::NotificationResponse {
354 process_id,
355 channel,
356 payload,
357 })
358}
359
360fn parse_function_call_response(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
361 let len = cur.read_i32()?;
362 if len == -1 {
363 return Ok(BackendMessage::FunctionCallResponse(None));
364 }
365 if len < 0 {
366 return Err(ProtocolError::InvalidField("negative function length"));
367 }
368 let bytes = cur.read_bytes(len as usize)?.to_vec();
369 Ok(BackendMessage::FunctionCallResponse(Some(bytes)))
370}
371
372fn parse_negotiate_protocol_version(cur: &mut Cursor<'_>) -> Result<BackendMessage, ProtocolError> {
373 let newest_minor = cur.read_i32()?;
374 let count = cur.read_i32()?;
375 if count < 0 {
376 return Err(ProtocolError::InvalidField(
377 "negative protocol option count",
378 ));
379 }
380 let mut unrecognized = Vec::with_capacity(count as usize);
381 for _ in 0..count {
382 unrecognized.push(cur.read_cstring()?);
383 }
384 Ok(BackendMessage::NegotiateProtocolVersion {
385 newest_minor,
386 unrecognized,
387 })
388}
389
390fn parse_error_response(
391 cur: &mut Cursor<'_>,
392 is_error: bool,
393) -> Result<BackendMessage, ProtocolError> {
394 let mut fields = ErrorFields::default();
395 loop {
396 let code = cur.read_u8()?;
397 if code == 0 {
398 break;
399 }
400 let value = cur.read_cstring()?;
401 match code {
402 b'S' => fields.severity = value,
403 b'V' => fields.severity_localized = Some(value),
404 b'C' => fields.code = value,
405 b'M' => fields.message = value,
406 b'D' => fields.detail = Some(value),
407 b'H' => fields.hint = Some(value),
408 b'P' => fields.position = value.parse().ok(),
409 b'p' => fields.internal_position = value.parse().ok(),
410 b'q' => fields.internal_query = Some(value),
411 b'W' => fields.where_ = Some(value),
412 b's' => fields.schema = Some(value),
413 b't' => fields.table = Some(value),
414 b'c' => fields.column = Some(value),
415 b'd' => fields.data_type = Some(value),
416 b'n' => fields.constraint = Some(value),
417 b'F' => fields.file = Some(value),
418 b'L' => fields.line = value.parse().ok(),
419 b'R' => fields.routine = Some(value),
420 _ => {
421 }
423 }
424 }
425
426 if is_error {
427 Ok(BackendMessage::ErrorResponse(fields))
428 } else {
429 Ok(BackendMessage::NoticeResponse(fields))
430 }
431}
432
433#[derive(Debug)]
434struct Cursor<'a> {
435 buf: &'a [u8],
436 pos: usize,
437}
438
439impl<'a> Cursor<'a> {
440 fn new(buf: &'a [u8]) -> Self {
441 Self { buf, pos: 0 }
442 }
443
444 fn remaining(&self) -> usize {
445 self.buf.len().saturating_sub(self.pos)
446 }
447
448 fn read_u8(&mut self) -> Result<u8, ProtocolError> {
449 if self.remaining() < 1 {
450 return Err(ProtocolError::UnexpectedEof);
451 }
452 let b = self.buf[self.pos];
453 self.pos += 1;
454 Ok(b)
455 }
456
457 fn read_i8(&mut self) -> Result<i8, ProtocolError> {
458 let b = self.read_u8()?;
459 Ok(b as i8)
460 }
461
462 fn read_i16(&mut self) -> Result<i16, ProtocolError> {
463 let bytes = self.read_bytes(2)?;
464 Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
465 }
466
467 fn read_u32(&mut self) -> Result<u32, ProtocolError> {
468 let bytes = self.read_bytes(4)?;
469 Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
470 }
471
472 fn read_i32(&mut self) -> Result<i32, ProtocolError> {
473 let bytes = self.read_bytes(4)?;
474 Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
475 }
476
477 fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], ProtocolError> {
478 if self.remaining() < n {
479 return Err(ProtocolError::UnexpectedEof);
480 }
481 let start = self.pos;
482 let end = self.pos + n;
483 self.pos = end;
484 Ok(&self.buf[start..end])
485 }
486
487 fn read_cstring(&mut self) -> Result<String, ProtocolError> {
488 let start = self.pos;
489 while self.pos < self.buf.len() && self.buf[self.pos] != 0 {
490 self.pos += 1;
491 }
492 if self.pos >= self.buf.len() {
493 return Err(ProtocolError::UnexpectedEof);
494 }
495 let bytes = self.buf[start..self.pos].to_vec();
496 self.pos += 1; Ok(String::from_utf8(bytes)?)
498 }
499
500 fn take_remaining(&mut self) -> Vec<u8> {
501 let remaining = self.buf[self.pos..].to_vec();
502 self.pos = self.buf.len();
503 remaining
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Frames `payload` with a type byte and a big-endian length that counts
    /// the length field itself.
    fn build_message(ty: u8, payload: &[u8]) -> Vec<u8> {
        let len = i32::try_from(payload.len() + 4).expect("test payload length fits in i32");
        let mut buf = Vec::with_capacity(payload.len() + 5);
        buf.push(ty);
        buf.extend_from_slice(&len.to_be_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    /// A complete ReadyForQuery(Idle) frame.
    fn ready_for_query_idle() -> Vec<u8> {
        build_message(
            backend_type::READY_FOR_QUERY,
            &[TransactionStatus::Idle.as_byte()],
        )
    }

    #[test]
    fn parse_auth_ok() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&auth_type::OK.to_be_bytes());
        let msg = build_message(backend_type::AUTHENTICATION, &payload);
        let decoded = MessageReader::parse_message(&msg).unwrap();
        assert!(matches!(decoded, BackendMessage::AuthenticationOk));
    }

    #[test]
    fn parse_ready_for_query() {
        let msg = ready_for_query_idle();
        let decoded = MessageReader::parse_message(&msg).unwrap();
        assert!(matches!(
            decoded,
            BackendMessage::ReadyForQuery(TransactionStatus::Idle)
        ));
    }

    #[test]
    fn parse_error_response() {
        let mut payload = Vec::new();
        payload.push(b'S');
        payload.extend_from_slice(b"ERROR\0");
        payload.push(b'C');
        payload.extend_from_slice(b"12345\0");
        payload.push(b'M');
        payload.extend_from_slice(b"bad\0");
        payload.push(0);

        let msg = build_message(backend_type::ERROR_RESPONSE, &payload);
        let decoded = MessageReader::parse_message(&msg).unwrap();
        match decoded {
            BackendMessage::ErrorResponse(fields) => {
                assert_eq!(fields.severity, "ERROR");
                assert_eq!(fields.code, "12345");
                assert_eq!(fields.message, "bad");
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn parse_data_row() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&(2_i16).to_be_bytes());
        payload.extend_from_slice(&(3_i32).to_be_bytes());
        payload.extend_from_slice(b"foo");
        payload.extend_from_slice(&(-1_i32).to_be_bytes());

        let msg = build_message(backend_type::DATA_ROW, &payload);
        let decoded = MessageReader::parse_message(&msg).unwrap();
        match decoded {
            BackendMessage::DataRow(values) => {
                assert_eq!(values.len(), 2);
                assert_eq!(values[0].as_deref(), Some(b"foo".as_slice()));
                assert!(values[1].is_none());
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn parse_notification_response() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&42_i32.to_be_bytes());
        payload.extend_from_slice(b"chan\0");
        payload.extend_from_slice(b"hello\0");

        let msg = build_message(backend_type::NOTIFICATION_RESPONSE, &payload);
        match MessageReader::parse_message(&msg).unwrap() {
            BackendMessage::NotificationResponse {
                process_id,
                channel,
                payload: body,
            } => {
                assert_eq!(process_id, 42);
                assert_eq!(channel, "chan");
                assert_eq!(body, "hello");
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn parse_negotiate_protocol_version() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0_i32.to_be_bytes());
        payload.extend_from_slice(&1_i32.to_be_bytes());
        payload.extend_from_slice(b"_pq_.opt\0");

        let msg = build_message(backend_type::NEGOTIATE_PROTOCOL_VERSION, &payload);
        match MessageReader::parse_message(&msg).unwrap() {
            BackendMessage::NegotiateProtocolVersion {
                newest_minor,
                unrecognized,
            } => {
                assert_eq!(newest_minor, 0);
                assert_eq!(unrecognized, vec!["_pq_.opt".to_string()]);
            }
            _ => panic!("unexpected message"),
        }
    }

    #[test]
    fn reader_buffers_partial_frames() {
        let msg = ready_for_query_idle();
        let (left, right) = msg.split_at(3);

        let mut reader = MessageReader::new();
        let first = reader.feed(left).unwrap();
        assert!(first.is_empty());

        let second = reader.feed(right).unwrap();
        assert_eq!(second.len(), 1);
    }

    #[test]
    fn reader_returns_every_complete_frame_in_one_chunk() {
        let partial = ready_for_query_idle();
        let mut stream = ready_for_query_idle();
        stream.extend_from_slice(&ready_for_query_idle());
        stream.extend_from_slice(&partial[..3]);

        let mut reader = MessageReader::new();
        let messages = reader.feed(&stream).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(reader.buffered_len(), 3);

        let rest = reader.feed(&partial[3..]).unwrap();
        assert_eq!(rest.len(), 1);
        assert_eq!(reader.buffered_len(), 0);
    }

    #[test]
    fn reader_rejects_length_below_four() {
        let mut frame = vec![backend_type::READY_FOR_QUERY];
        frame.extend_from_slice(&3_i32.to_be_bytes());

        let mut reader = MessageReader::new();
        let result = reader.feed(&frame);
        assert!(matches!(
            result,
            Err(ProtocolError::InvalidLength { length: 3 })
        ));
    }

    #[test]
    fn reader_rejects_oversized_frame_from_header_alone() {
        let msg = build_message(backend_type::COPY_DATA, &[0_u8; 32]);

        let mut reader = MessageReader::with_max_size(16);
        let result = reader.feed(&msg[..5]);
        assert!(matches!(
            result,
            Err(ProtocolError::MessageTooLarge { length: 37, max: 16 })
        ));
    }

    #[test]
    fn reader_discards_unparseable_frame_and_continues() {
        let mut stream = build_message(0x01, &[]);
        stream.extend_from_slice(&ready_for_query_idle());

        let mut reader = MessageReader::new();
        assert!(matches!(
            reader.feed(&stream),
            Err(ProtocolError::UnknownMessageType(0x01))
        ));

        let recovered = reader.feed(&[]).unwrap();
        assert_eq!(recovered.len(), 1);
        assert!(matches!(
            recovered[0],
            BackendMessage::ReadyForQuery(TransactionStatus::Idle)
        ));
    }

    #[test]
    fn parse_row_description_negative_count_rejected() {
        let payload = (-1_i16).to_be_bytes();
        let msg = build_message(backend_type::ROW_DESCRIPTION, &payload);
        let result = MessageReader::parse_message(&msg);
        assert!(matches!(result, Err(ProtocolError::InvalidField(_))));
    }

    #[test]
    fn parse_data_row_negative_count_rejected() {
        let payload = (-1_i16).to_be_bytes();
        let msg = build_message(backend_type::DATA_ROW, &payload);
        let result = MessageReader::parse_message(&msg);
        assert!(matches!(result, Err(ProtocolError::InvalidField(_))));
    }

    #[test]
    fn parse_parameter_description_negative_count_rejected() {
        let payload = (-1_i16).to_be_bytes();
        let msg = build_message(backend_type::PARAMETER_DESCRIPTION, &payload);
        let result = MessageReader::parse_message(&msg);
        assert!(matches!(result, Err(ProtocolError::InvalidField(_))));
    }
}