1use bytes::{Buf, BufMut, BytesMut};
2use std::collections::HashMap;
3use std::io;
4use thiserror::Error;
5
/// Client-supplied tuning for selective (partial-row) subscription updates.
///
/// Carried in the optional trailing section of a `Subscribe` (0xF0) message. A
/// flags byte selects which fields follow: 0x01 `enabled`, 0x02
/// `min_changed_columns`, 0x04 `max_changed_columns_ratio`. A field stays `None`
/// when its flag bit is clear or when too few bytes remain to read it.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectiveUpdatesConfig {
    /// Flag 0x01; sent as one byte, where any non-zero value decodes to `true`.
    pub enabled: Option<bool>,
    /// Flag 0x02; sent as a big-endian `u16`.
    /// NOTE(review): presumably the minimum number of changed columns before a
    /// partial update is preferred — confirm against the subscription engine.
    pub min_changed_columns: Option<usize>,
    /// Flag 0x04; sent as a big-endian `f64`.
    /// NOTE(review): presumably an upper bound on changed/total columns for
    /// choosing a partial update — confirm against the subscription engine.
    pub max_changed_columns_ratio: Option<f64>,
}
21
/// Errors raised by the message codec in this module.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// An underlying I/O failure; convertible from `io::Error` via `#[from]`.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),

    /// The leading message-type byte is not one the decoder recognises.
    #[error("Invalid message type: {0}")]
    InvalidMessageType(u8),

    /// The message body ended before a required field could be read.
    #[error("Message too short")]
    MessageTooShort,

    /// The declared length field is below the minimum for the frame kind
    /// (4 for regular messages, 8 for the startup packet).
    #[error("Invalid message length: {0}")]
    InvalidMessageLength(i32),

    /// A string field had no NUL terminator or was not valid UTF-8.
    #[error("Invalid string encoding")]
    InvalidString,

    /// NOTE(review): not constructed in this file (hence the `dead_code`
    /// allowance); presumably reserved for connection-state handling — confirm.
    #[error("Unexpected message: {0}")]
    #[allow(dead_code)]
    UnexpectedMessage(String),
}
44
/// Kind of change carried by a subscription data message.
///
/// The explicit discriminant is what goes on the wire: the encoder writes
/// `update_type as u8` as a single byte, so these values must not change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SubscriptionUpdateType {
    /// Wire byte 0.
    Full = 0,
    /// Wire byte 1.
    DeltaInsert = 1,
    /// Wire byte 2.
    DeltaUpdate = 2,
    /// Wire byte 3.
    DeltaDelete = 3,
    /// Wire byte 4; `SubscriptionPartialData` frames are always tagged with it.
    SelectiveUpdate = 4,
}
57
/// A row update that carries values for only some of a row's columns.
///
/// On the wire (`SubscriptionPartialData`) a row is sent as `total_columns`, then the
/// raw `column_mask` bytes, then one length-prefixed value per entry of `values`.
/// No column indices are sent, so a receiver can only pair values with columns by
/// walking the set bits of the mask. `values[k]` must therefore belong to the k-th
/// set bit in ascending column order. [`PartialRowUpdate::new`] establishes that
/// invariant.
#[derive(Debug, Clone, PartialEq)]
pub struct PartialRowUpdate {
    /// Number of columns in the complete row.
    pub total_columns: u16,
    /// Presence bitmap of `ceil(total_columns / 8)` bytes. Column `i` is present when
    /// bit `i % 8` (least-significant bit first) of byte `i / 8` is set.
    pub column_mask: Vec<u8>,
    /// Values of the present columns in ascending column order; `None` is encoded
    /// with length -1.
    pub values: Vec<Option<Vec<u8>>>,
}

impl PartialRowUpdate {
    /// Builds a partial update from parallel lists of column indices and values.
    ///
    /// `present_columns[k]` names the column that `values[k]` belongs to, and the
    /// lists may be in any order. The pairs are sorted into ascending column order so
    /// that `values` lines up with the bits of `column_mask`. Pairs whose column index
    /// is `>= total_columns` are dropped together with their value. If a column is
    /// listed more than once, only its first value is kept. As a result, the number
    /// of stored values always equals the number of set mask bits.
    ///
    /// # Panics
    ///
    /// In debug builds, if `present_columns` and `values` differ in length. In release
    /// builds the surplus entries of the longer list are ignored.
    pub fn new(total_columns: u16, present_columns: &[u16], values: Vec<Option<Vec<u8>>>) -> Self {
        debug_assert_eq!(present_columns.len(), values.len());

        let mut pairs: Vec<(u16, Option<Vec<u8>>)> = present_columns
            .iter()
            .copied()
            .zip(values)
            .filter(|&(col, _)| col < total_columns)
            .collect();
        // Stable sort, then dedup: for a repeated column the first supplied value wins.
        pairs.sort_by_key(|&(col, _)| col);
        pairs.dedup_by_key(|(col, _)| *col);

        let mut column_mask = vec![0u8; usize::from(total_columns).div_ceil(8)];
        let mut ordered_values = Vec::with_capacity(pairs.len());
        for (col, value) in pairs {
            column_mask[usize::from(col / 8)] |= 1 << (col % 8);
            ordered_values.push(value);
        }

        Self { total_columns, column_mask, values: ordered_values }
    }

    /// Returns `true` if column `col_idx` is marked present in the mask.
    ///
    /// Indices at or beyond `total_columns` are always reported as absent.
    pub fn is_column_present(&self, col_idx: u16) -> bool {
        if col_idx >= self.total_columns {
            return false;
        }
        let byte_idx = usize::from(col_idx / 8);
        let bit_idx = col_idx % 8;
        self.column_mask
            .get(byte_idx)
            .is_some_and(|byte| byte & (1 << bit_idx) != 0)
    }

    /// Returns the number of columns marked present (set bits in the mask).
    pub fn present_column_count(&self) -> usize {
        self.column_mask.iter().map(|b| b.count_ones() as usize).sum()
    }
}
119
/// Messages sent from the server to the client, serialised by `BackendMessage::encode`.
///
/// The standard variants use the PostgreSQL v3 message tags noted below. The
/// subscription variants use custom tags (0xF2, 0xF3, 0xF4, 0xF7).
#[derive(Debug, Clone, PartialEq)]
pub enum BackendMessage {
    /// Authentication succeeded (`'R'`, request code 0).
    AuthenticationOk,
    /// Ask for a cleartext password (`'R'`, request code 3).
    // NOTE(review): not constructed in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    AuthenticationCleartextPassword,
    /// Ask for an MD5 password using the given 4-byte salt (`'R'`, request code 5).
    #[allow(dead_code)]
    AuthenticationMD5Password { salt: [u8; 4] },

    /// Parameter report (`'S'`); name and value are sent as NUL-terminated strings.
    ParameterStatus { name: String, value: String },

    /// Process id and secret key (`'K'`), each a big-endian `i32`.
    BackendKeyData { process_id: i32, secret_key: i32 },

    /// Ready indicator (`'Z'`) with one status byte from `TransactionStatus::as_byte`.
    ReadyForQuery { status: TransactionStatus },

    /// Column metadata for a result set (`'T'`).
    RowDescription { fields: Vec<FieldDescription> },

    /// One result row (`'D'`); `None` values are written with length -1.
    DataRow { values: Vec<Option<Vec<u8>>> },

    /// Command completion tag (`'C'`), sent NUL-terminated.
    CommandComplete { tag: String },

    /// Error report (`'E'`); map keys are one-byte field codes.
    ErrorResponse { fields: HashMap<u8, String> },

    /// Notice report (`'N'`), same layout as `ErrorResponse`.
    // NOTE(review): not constructed in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    NoticeResponse { fields: HashMap<u8, String> },

    /// Response to an empty query string (`'I'`), no body.
    EmptyQueryResponse,

    /// Subscription rows (0xF2): 16-byte id, update-type byte, `i32` row count, then
    /// per row an `i16` value count followed by length-prefixed nullable values.
    SubscriptionData {
        subscription_id: [u8; 16],
        update_type: SubscriptionUpdateType,
        rows: Vec<Vec<Option<Vec<u8>>>>,
    },

    /// Subscription failure (0xF3): 16-byte id and a NUL-terminated message.
    SubscriptionError { subscription_id: [u8; 16], message: String },

    /// Subscription acknowledgement (0xF4): 16-byte id and a big-endian `u16` table count.
    SubscriptionAck {
        subscription_id: [u8; 16],
        table_count: u16,
    },

    /// Partial-row subscription update (0xF7), always tagged
    /// `SubscriptionUpdateType::SelectiveUpdate`. Each row is written as
    /// `total_columns`, the column mask bytes and the present values.
    SubscriptionPartialData {
        subscription_id: [u8; 16],
        rows: Vec<PartialRowUpdate>,
    },
}
199
/// Transaction state reported to the client inside a `ReadyForQuery` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Reported as `b'I'`.
    Idle,
    /// Reported as `b'T'`.
    // NOTE(review): not constructed in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    InTransaction,
    /// Reported as `b'E'`.
    #[allow(dead_code)]
    FailedTransaction,
}

impl TransactionStatus {
    /// Returns the single ASCII status indicator written into a `ReadyForQuery` body.
    pub fn as_byte(&self) -> u8 {
        let indicator = match *self {
            Self::Idle => 'I',
            Self::InTransaction => 'T',
            Self::FailedTransaction => 'E',
        };
        // All indicators are ASCII, so the narrowing cast is exact.
        indicator as u8
    }
}
222
/// Metadata for one column of a `RowDescription` message.
///
/// `name` is written NUL-terminated. The remaining fields follow in declaration
/// order as big-endian integers, 18 bytes in total.
#[derive(Debug, Clone, PartialEq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the originating table.
    pub table_oid: i32,
    /// Attribute number of the column within that table.
    pub column_attr_number: i16,
    /// OID of the column's data type.
    pub data_type_oid: i32,
    /// Size of the data type.
    /// NOTE(review): presumably negative for variable-width types, as in PostgreSQL — confirm.
    pub data_type_size: i16,
    /// Type-specific modifier.
    pub type_modifier: i32,
    /// Format code for the column's values.
    /// NOTE(review): presumably 0 = text, 1 = binary, as in PostgreSQL — confirm.
    pub format_code: i16,
}
234
/// Messages received from the client.
///
/// `Startup` and `SSLRequest` are produced by `FrontendMessage::decode_startup`;
/// every other variant comes from `FrontendMessage::decode`.
#[derive(Debug, Clone, PartialEq)]
pub enum FrontendMessage {
    /// Startup packet: protocol version followed by NUL-terminated key/value pairs.
    Startup { protocol_version: i32, params: HashMap<String, String> },

    /// Password response (`'p'`).
    Password { password: String },

    /// Simple query (`'Q'`).
    Query { query: String },

    /// Client is closing the connection (`'X'`).
    Terminate,

    /// Startup packet whose version field is the SSL request code 80877103.
    SSLRequest,

    /// Open a subscription (0xF0).
    Subscribe {
        /// Query text, NUL-terminated on the wire.
        query: String,
        /// Parameter values; a negative length on the wire decodes to `None`.
        params: Vec<Option<Vec<u8>>>,
        /// Optional filter text. `None` when the section is absent, its declared
        /// length is not positive, or the declared bytes are missing.
        filter: Option<String>,
        /// Optional tuning; `None` when the flags byte is absent or zero.
        selective_updates_config: Option<SelectiveUpdatesConfig>,
    },

    /// Cancel a subscription (0xF1), identified by its 16-byte id.
    Unsubscribe { subscription_id: [u8; 16] },

    /// Pause a subscription (0xF5).
    SubscriptionPause { subscription_id: [u8; 16] },

    /// Resume a paused subscription (0xF6).
    SubscriptionResume { subscription_id: [u8; 16] },
}
272
273impl BackendMessage {
274 pub fn encode(&self, buf: &mut BytesMut) {
276 match self {
277 BackendMessage::AuthenticationOk => {
278 buf.put_u8(b'R'); buf.put_i32(8); buf.put_i32(0); }
282
283 BackendMessage::AuthenticationCleartextPassword => {
284 buf.put_u8(b'R');
285 buf.put_i32(8);
286 buf.put_i32(3); }
288
289 BackendMessage::AuthenticationMD5Password { salt } => {
290 buf.put_u8(b'R');
291 buf.put_i32(12);
292 buf.put_i32(5); buf.put_slice(salt);
294 }
295
296 BackendMessage::ParameterStatus { name, value } => {
297 buf.put_u8(b'S'); let len = 4 + name.len() + 1 + value.len() + 1;
299 buf.put_i32(len as i32);
300 put_cstring(buf, name);
301 put_cstring(buf, value);
302 }
303
304 BackendMessage::BackendKeyData { process_id, secret_key } => {
305 buf.put_u8(b'K'); buf.put_i32(12);
307 buf.put_i32(*process_id);
308 buf.put_i32(*secret_key);
309 }
310
311 BackendMessage::ReadyForQuery { status } => {
312 buf.put_u8(b'Z'); buf.put_i32(5);
314 buf.put_u8(status.as_byte());
315 }
316
317 BackendMessage::RowDescription { fields } => {
318 buf.put_u8(b'T'); let mut len = 4 + 2; for field in fields {
323 len += field.name.len() + 1 + 18; }
325
326 buf.put_i32(len as i32);
327 buf.put_i16(fields.len() as i16);
328
329 for field in fields {
330 put_cstring(buf, &field.name);
331 buf.put_i32(field.table_oid);
332 buf.put_i16(field.column_attr_number);
333 buf.put_i32(field.data_type_oid);
334 buf.put_i16(field.data_type_size);
335 buf.put_i32(field.type_modifier);
336 buf.put_i16(field.format_code);
337 }
338 }
339
340 BackendMessage::DataRow { values } => {
341 buf.put_u8(b'D'); let mut len = 4 + 2; for value in values {
346 len += 4; if let Some(v) = value {
348 len += v.len();
349 }
350 }
351
352 buf.put_i32(len as i32);
353 buf.put_i16(values.len() as i16);
354
355 for value in values {
356 match value {
357 Some(v) => {
358 buf.put_i32(v.len() as i32);
359 buf.put_slice(v);
360 }
361 None => {
362 buf.put_i32(-1); }
364 }
365 }
366 }
367
368 BackendMessage::CommandComplete { tag } => {
369 buf.put_u8(b'C'); let len = 4 + tag.len() + 1;
371 buf.put_i32(len as i32);
372 put_cstring(buf, tag);
373 }
374
375 BackendMessage::ErrorResponse { fields } => {
376 buf.put_u8(b'E'); encode_notice_or_error(buf, fields);
378 }
379
380 BackendMessage::NoticeResponse { fields } => {
381 buf.put_u8(b'N'); encode_notice_or_error(buf, fields);
383 }
384
385 BackendMessage::EmptyQueryResponse => {
386 buf.put_u8(b'I'); buf.put_i32(4);
388 }
389
390 BackendMessage::SubscriptionData { subscription_id, update_type, rows } => {
391 buf.put_u8(0xF2); let mut len = 4 + 16 + 1 + 4; for row in rows {
396 len += 2; for value in row {
398 len += 4; if let Some(v) = value {
400 len += v.len();
401 }
402 }
403 }
404
405 buf.put_i32(len as i32);
406 buf.put_slice(subscription_id);
407 buf.put_u8(*update_type as u8);
408 buf.put_i32(rows.len() as i32);
409
410 for row in rows {
411 buf.put_i16(row.len() as i16);
412 for value in row {
413 match value {
414 Some(v) => {
415 buf.put_i32(v.len() as i32);
416 buf.put_slice(v);
417 }
418 None => {
419 buf.put_i32(-1); }
421 }
422 }
423 }
424 }
425
426 BackendMessage::SubscriptionError { subscription_id, message } => {
427 buf.put_u8(0xF3); let msg_bytes = message.as_bytes();
430 let len = 4 + 16 + msg_bytes.len() + 1; buf.put_i32(len as i32);
433 buf.put_slice(subscription_id);
434 put_cstring(buf, message);
435 }
436
437BackendMessage::SubscriptionAck { subscription_id, table_count } => {
438 buf.put_u8(0xF4); let len: i32 = 4 + 16 + 2; buf.put_i32(len);
443 buf.put_slice(subscription_id);
444 buf.put_u16(*table_count);
445 }
446
447 BackendMessage::SubscriptionPartialData { subscription_id, rows } => {
448 buf.put_u8(0xF7); let mut len = 4 + 16 + 1 + 4;
453 for row in rows {
454 len += 2;
456 len += row.column_mask.len();
457 for value in &row.values {
458 len += 4; if let Some(v) = value {
460 len += v.len();
461 }
462 }
463 }
464
465 buf.put_i32(len as i32);
466 buf.put_slice(subscription_id);
467 buf.put_u8(SubscriptionUpdateType::SelectiveUpdate as u8);
468 buf.put_i32(rows.len() as i32);
469
470 for row in rows {
471 buf.put_i16(row.total_columns as i16);
472 buf.put_slice(&row.column_mask);
473 for value in &row.values {
474 match value {
475 Some(v) => {
476 buf.put_i32(v.len() as i32);
477 buf.put_slice(v);
478 }
479 None => {
480 buf.put_i32(-1); }
482 }
483 }
484 }
485 }
486 }
487 }
488}
489
490impl FrontendMessage {
491 pub fn decode(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
493 if buf.len() < 5 {
495 return Ok(None);
496 }
497
498 let msg_type = buf[0];
500
501 let len_i32 = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
503
504 if len_i32 < 4 {
507 return Err(ProtocolError::InvalidMessageLength(len_i32));
508 }
509
510 let len = len_i32 as usize;
511
512 let total_len = 1usize.saturating_add(len);
514 if buf.len() < total_len {
515 return Ok(None);
516 }
517
518 buf.advance(1);
520
521 match msg_type {
523 b'Q' => {
524 buf.advance(4); let query = read_cstring(buf)?;
527 Ok(Some(FrontendMessage::Query { query }))
528 }
529
530 b'p' => {
531 buf.advance(4); let password = read_cstring(buf)?;
534 Ok(Some(FrontendMessage::Password { password }))
535 }
536
537 b'X' => {
538 buf.advance(4); Ok(Some(FrontendMessage::Terminate))
541 }
542
543 0xF0 => {
544 buf.advance(4); let query = read_cstring(buf)?;
547 let param_count = buf.get_i16() as usize;
548 let mut params = Vec::with_capacity(param_count);
549
550 for _ in 0..param_count {
551 let param_len = buf.get_i32();
552 if param_len < 0 {
553 params.push(None);
554 } else {
555 let mut param = vec![0u8; param_len as usize];
556 buf.copy_to_slice(&mut param);
557 params.push(Some(param));
558 }
559 }
560
561 let filter = if buf.remaining() >= 2 {
564 let filter_len = buf.get_i16();
565 if filter_len > 0 {
566 let filter_len = filter_len as usize;
567 if buf.remaining() >= filter_len {
568 let mut filter_bytes = vec![0u8; filter_len];
569 buf.copy_to_slice(&mut filter_bytes);
570 Some(
571 String::from_utf8(filter_bytes)
572 .map_err(|_| ProtocolError::InvalidString)?,
573 )
574 } else {
575 None }
577 } else {
578 None }
580 } else {
581 None };
583
584 let selective_updates_config = if buf.remaining() >= 1 {
590 let config_flags = buf.get_u8();
591 if config_flags != 0 {
592 let mut config = SelectiveUpdatesConfig {
593 enabled: None,
594 min_changed_columns: None,
595 max_changed_columns_ratio: None,
596 };
597
598 if (config_flags & 0x01) != 0 && buf.remaining() >= 1 {
600 config.enabled = Some(buf.get_u8() != 0);
601 }
602
603 if (config_flags & 0x02) != 0 && buf.remaining() >= 2 {
605 config.min_changed_columns = Some(buf.get_u16() as usize);
606 }
607
608 if (config_flags & 0x04) != 0 && buf.remaining() >= 8 {
610 config.max_changed_columns_ratio = Some(buf.get_f64());
611 }
612
613 Some(config)
614 } else {
615 None }
617 } else {
618 None };
620
621 Ok(Some(FrontendMessage::Subscribe { query, params, filter, selective_updates_config }))
622 }
623
624 0xF1 => {
625 buf.advance(4); let mut subscription_id = [0u8; 16];
628 buf.copy_to_slice(&mut subscription_id);
629 Ok(Some(FrontendMessage::Unsubscribe { subscription_id }))
630 }
631
632 0xF5 => {
633 buf.advance(4); let mut subscription_id = [0u8; 16];
636 buf.copy_to_slice(&mut subscription_id);
637 Ok(Some(FrontendMessage::SubscriptionPause { subscription_id }))
638 }
639
640 0xF6 => {
641 buf.advance(4); let mut subscription_id = [0u8; 16];
644 buf.copy_to_slice(&mut subscription_id);
645 Ok(Some(FrontendMessage::SubscriptionResume { subscription_id }))
646 }
647
648 _ => Err(ProtocolError::InvalidMessageType(msg_type)),
649 }
650 }
651
652 pub fn decode_startup(buf: &mut BytesMut) -> Result<Option<Self>, ProtocolError> {
654 if buf.len() < 4 {
655 return Ok(None);
656 }
657
658 let len_i32 = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
659
660 if len_i32 < 8 {
663 return Err(ProtocolError::InvalidMessageLength(len_i32));
664 }
665
666 let len = len_i32 as usize;
667
668 if buf.len() < len {
669 return Ok(None);
670 }
671
672 buf.advance(4); let protocol_version = buf.get_i32();
675
676 if protocol_version == 80877103 {
678 return Ok(Some(FrontendMessage::SSLRequest));
679 }
680
681 let mut params = HashMap::new();
683 let max_params = 100; for _ in 0..max_params {
685 if buf.is_empty() {
687 break;
688 }
689 let key = read_cstring(buf)?;
690 if key.is_empty() {
691 break;
692 }
693 let value = read_cstring(buf)?;
694 params.insert(key, value);
695 }
696
697 Ok(Some(FrontendMessage::Startup { protocol_version, params }))
698 }
699}
700
701fn put_cstring(buf: &mut BytesMut, s: &str) {
703 buf.put_slice(s.as_bytes());
704 buf.put_u8(0);
705}
706
707fn read_cstring(buf: &mut BytesMut) -> Result<String, ProtocolError> {
709 let null_pos = buf.iter().position(|&b| b == 0).ok_or(ProtocolError::InvalidString)?;
710
711 let bytes = buf.split_to(null_pos);
712 buf.advance(1); String::from_utf8(bytes.to_vec()).map_err(|_| ProtocolError::InvalidString)
715}
716
717fn encode_notice_or_error(buf: &mut BytesMut, fields: &HashMap<u8, String>) {
719 let mut len = 4 + 1; for value in fields.values() {
722 len += 1 + value.len() + 1; }
724
725 buf.put_i32(len as i32);
726
727 for (&field_type, value) in fields {
729 buf.put_u8(field_type);
730 put_cstring(buf, value);
731 }
732
733 buf.put_u8(0);
735}
736
737#[cfg(test)]
738mod tests {
739 use super::*;
740
741 #[test]
742 fn test_authentication_ok_encoding() {
743 let mut buf = BytesMut::new();
744 BackendMessage::AuthenticationOk.encode(&mut buf);
745
746 assert_eq!(buf[0], b'R');
747 assert_eq!(&buf[1..5], &[0, 0, 0, 8]);
748 assert_eq!(&buf[5..9], &[0, 0, 0, 0]);
749 }
750
751 #[test]
752 fn test_ready_for_query_encoding() {
753 let mut buf = BytesMut::new();
754 BackendMessage::ReadyForQuery { status: TransactionStatus::Idle }.encode(&mut buf);
755
756 assert_eq!(buf[0], b'Z');
757 assert_eq!(&buf[1..5], &[0, 0, 0, 5]);
758 assert_eq!(buf[5], b'I');
759 }
760
761 #[test]
762 fn test_query_decoding() {
763 let mut buf = BytesMut::new();
764 buf.put_u8(b'Q'); buf.put_i32(13); buf.put_slice(b"SELECT 1\0");
767
768 let msg = FrontendMessage::decode(&mut buf).unwrap();
769 assert!(matches!(
770 msg,
771 Some(FrontendMessage::Query { query }) if query == "SELECT 1"
772 ));
773 }
774
775 #[test]
776 fn test_subscribe_message_parsing() {
777 let mut buf = BytesMut::new();
778 buf.put_u8(0xF0); let mut content = BytesMut::new();
780 content.put_slice(b"SELECT * FROM users\0");
781 content.put_i16(0); buf.put_i32((4 + content.len()) as i32);
784 buf.extend(content);
785
786 let msg = FrontendMessage::decode(&mut buf).unwrap();
787 assert!(matches!(
788 msg,
789 Some(FrontendMessage::Subscribe { query, params, filter, .. })
790 if query == "SELECT * FROM users" && params.is_empty() && filter.is_none()
791 ));
792 }
793
794 #[test]
795 fn test_subscribe_with_parameters() {
796 let mut buf = BytesMut::new();
797 buf.put_u8(0xF0); let mut content = BytesMut::new();
799 content.put_slice(b"SELECT * FROM users WHERE id = $1\0");
800 content.put_i16(1); content.put_i32(5); content.put_slice(b"12345");
803
804 buf.put_i32((4 + content.len()) as i32);
805 buf.extend(content);
806
807 let msg = FrontendMessage::decode(&mut buf).unwrap();
808 assert!(matches!(
809 msg,
810 Some(FrontendMessage::Subscribe { query, params, filter, .. })
811 if query == "SELECT * FROM users WHERE id = $1" && params.len() == 1 && filter.is_none()
812 ));
813 }
814
815 #[test]
816 fn test_subscribe_with_filter() {
817 let mut buf = BytesMut::new();
818 buf.put_u8(0xF0); let mut content = BytesMut::new();
820 content.put_slice(b"SELECT * FROM users\0");
821 content.put_i16(0); let filter_str = "status = 'active'";
823 content.put_i16(filter_str.len() as i16); content.put_slice(filter_str.as_bytes()); buf.put_i32((4 + content.len()) as i32);
827 buf.extend(content);
828
829 let msg = FrontendMessage::decode(&mut buf).unwrap();
830 match msg {
831 Some(FrontendMessage::Subscribe { query, params, filter, .. }) => {
832 assert_eq!(query, "SELECT * FROM users");
833 assert!(params.is_empty());
834 assert_eq!(filter, Some("status = 'active'".to_string()));
835 }
836 _ => panic!("Expected Subscribe message"),
837 }
838 }
839
840 #[test]
841 fn test_subscribe_with_empty_filter() {
842 let mut buf = BytesMut::new();
843 buf.put_u8(0xF0); let mut content = BytesMut::new();
845 content.put_slice(b"SELECT * FROM users\0");
846 content.put_i16(0); content.put_i16(0); buf.put_i32((4 + content.len()) as i32);
850 buf.extend(content);
851
852 let msg = FrontendMessage::decode(&mut buf).unwrap();
853 assert!(matches!(
854 msg,
855 Some(FrontendMessage::Subscribe { query, params, filter, .. })
856 if query == "SELECT * FROM users" && params.is_empty() && filter.is_none()
857 ));
858 }
859
860 #[test]
861 fn test_unsubscribe_message_parsing() {
862 let mut buf = BytesMut::new();
863 buf.put_u8(0xF1); buf.put_i32(20); buf.put_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
866
867 let msg = FrontendMessage::decode(&mut buf).unwrap();
868 assert!(matches!(
869 msg,
870 Some(FrontendMessage::Unsubscribe { subscription_id })
871 if subscription_id == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
872 ));
873 }
874
875 #[test]
876 fn test_subscription_data_encoding() {
877 let mut buf = BytesMut::new();
878 let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
879 let rows = vec![vec![Some(b"value1".to_vec()), Some(b"value2".to_vec())]];
880
881 let msg = BackendMessage::SubscriptionData {
882 subscription_id,
883 update_type: SubscriptionUpdateType::Full,
884 rows,
885 };
886 msg.encode(&mut buf);
887
888 assert_eq!(buf[0], 0xF2);
889 assert_eq!(&buf[5..21], subscription_id.as_ref());
891 }
892
893 #[test]
894 fn test_subscription_error_encoding() {
895 let mut buf = BytesMut::new();
896 let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
897
898 let msg = BackendMessage::SubscriptionError {
899 subscription_id,
900 message: "Query error".to_string(),
901 };
902 msg.encode(&mut buf);
903
904 assert_eq!(buf[0], 0xF3);
905 assert_eq!(&buf[5..21], subscription_id.as_ref());
907 }
908
909 #[test]
910 fn test_subscription_ack_encoding() {
911 let mut buf = BytesMut::new();
912 let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
913
914 let msg = BackendMessage::SubscriptionAck { subscription_id, table_count: 3 };
915 msg.encode(&mut buf);
916
917 assert_eq!(buf[0], 0xF4);
918 assert_eq!(&buf[1..5], &[0, 0, 0, 22]);
920 assert_eq!(&buf[5..21], subscription_id.as_ref());
922 assert_eq!(&buf[21..23], &[0, 3]);
924 }
925
926 #[test]
927 fn test_subscription_pause_parsing() {
928 let mut buf = BytesMut::new();
929 buf.put_u8(0xF5); buf.put_i32(20); buf.put_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
932
933 let msg = FrontendMessage::decode(&mut buf).unwrap();
934 assert!(matches!(
935 msg,
936 Some(FrontendMessage::SubscriptionPause { subscription_id })
937 if subscription_id == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
938 ));
939 }
940
941 #[test]
942 fn test_subscription_resume_parsing() {
943 let mut buf = BytesMut::new();
944 buf.put_u8(0xF6); buf.put_i32(20); buf.put_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
947
948 let msg = FrontendMessage::decode(&mut buf).unwrap();
949 assert!(matches!(
950 msg,
951 Some(FrontendMessage::SubscriptionResume { subscription_id })
952 if subscription_id == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
953 ));
954 }
955
956 #[test]
957 fn test_subscription_partial_data_encoding() {
958 let mut buf = BytesMut::new();
959 let subscription_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
960
961 let partial_row = PartialRowUpdate::new(
963 4,
964 &[0, 2],
965 vec![Some(b"id1".to_vec()), Some(b"value".to_vec())],
966 );
967
968 let msg = BackendMessage::SubscriptionPartialData {
969 subscription_id,
970 rows: vec![partial_row],
971 };
972 msg.encode(&mut buf);
973
974 assert_eq!(buf[0], 0xF7);
976
977 assert_eq!(&buf[5..21], subscription_id.as_ref());
979
980 assert_eq!(buf[21], 4);
982
983 let row_count = i32::from_be_bytes([buf[22], buf[23], buf[24], buf[25]]);
985 assert_eq!(row_count, 1);
986
987 let total_cols = i16::from_be_bytes([buf[26], buf[27]]);
989 assert_eq!(total_cols, 4);
990
991 assert_eq!(buf[28], 0b00000101);
994 }
995
996 #[test]
997 fn test_subscription_partial_data_encoding_with_null() {
998 let mut buf = BytesMut::new();
999 let subscription_id = [0u8; 16];
1000
1001 let partial_row = PartialRowUpdate::new(
1003 3,
1004 &[0, 1],
1005 vec![Some(b"1".to_vec()), None], );
1007
1008 let msg = BackendMessage::SubscriptionPartialData {
1009 subscription_id,
1010 rows: vec![partial_row],
1011 };
1012 msg.encode(&mut buf);
1013
1014 assert_eq!(buf[0], 0xF7);
1015
1016 let null_pos = 34;
1025 let null_len = i32::from_be_bytes([buf[null_pos], buf[null_pos + 1], buf[null_pos + 2], buf[null_pos + 3]]);
1026 assert_eq!(null_len, -1);
1027 }
1028
1029 #[test]
1030 fn test_partial_row_update_new() {
1031 let partial = PartialRowUpdate::new(
1033 16,
1034 &[0, 8, 15],
1035 vec![Some(b"a".to_vec()), Some(b"b".to_vec()), Some(b"c".to_vec())],
1036 );
1037
1038 assert_eq!(partial.total_columns, 16);
1039 assert_eq!(partial.column_mask.len(), 2); assert_eq!(partial.column_mask[0], 0b00000001);
1044 assert_eq!(partial.column_mask[1], 0b10000001);
1045
1046 assert!(partial.is_column_present(0));
1047 assert!(!partial.is_column_present(1));
1048 assert!(partial.is_column_present(8));
1049 assert!(partial.is_column_present(15));
1050 assert!(!partial.is_column_present(16)); }
1052
1053 mod malformed_message_tests {
1059 use super::*;
1060
1061 #[test]
1066 fn test_truncated_message_empty_buffer() {
1067 let mut buf = BytesMut::new();
1068 let result = FrontendMessage::decode(&mut buf);
1070 assert!(result.is_ok());
1071 assert!(result.unwrap().is_none());
1072 }
1073
1074 #[test]
1075 fn test_truncated_message_only_type_byte() {
1076 let mut buf = BytesMut::new();
1077 buf.put_u8(b'Q'); let result = FrontendMessage::decode(&mut buf);
1079 assert!(result.is_ok());
1080 assert!(result.unwrap().is_none());
1081 }
1082
1083 #[test]
1084 fn test_truncated_message_partial_length() {
1085 let mut buf = BytesMut::new();
1086 buf.put_u8(b'Q');
1087 buf.put_u8(0); buf.put_u8(0);
1089 let result = FrontendMessage::decode(&mut buf);
1090 assert!(result.is_ok());
1091 assert!(result.unwrap().is_none());
1092 }
1093
1094 #[test]
1095 fn test_truncated_message_incomplete_body() {
1096 let mut buf = BytesMut::new();
1097 buf.put_u8(b'Q');
1098 buf.put_i32(100); buf.put_slice(b"SELECT"); let result = FrontendMessage::decode(&mut buf);
1101 assert!(result.is_ok());
1102 assert!(result.unwrap().is_none());
1103 }
1104
1105 #[test]
1106 fn test_truncated_startup_empty_buffer() {
1107 let mut buf = BytesMut::new();
1108 let result = FrontendMessage::decode_startup(&mut buf);
1109 assert!(result.is_ok());
1110 assert!(result.unwrap().is_none());
1111 }
1112
1113 #[test]
1114 fn test_truncated_startup_partial_length() {
1115 let mut buf = BytesMut::new();
1116 buf.put_u8(0);
1117 buf.put_u8(0); let result = FrontendMessage::decode_startup(&mut buf);
1119 assert!(result.is_ok());
1120 assert!(result.unwrap().is_none());
1121 }
1122
1123 #[test]
1124 fn test_truncated_startup_incomplete_body() {
1125 let mut buf = BytesMut::new();
1126 buf.put_i32(50); buf.put_i32(196608); buf.put_slice(b"user\0"); let result = FrontendMessage::decode_startup(&mut buf);
1130 assert!(result.is_ok());
1131 assert!(result.unwrap().is_none());
1132 }
1133
1134 #[test]
1139 fn test_invalid_message_type_byte() {
1140 let mut buf = BytesMut::new();
1141 buf.put_u8(0xFF); buf.put_i32(4); let result = FrontendMessage::decode(&mut buf);
1144 assert!(matches!(result, Err(ProtocolError::InvalidMessageType(0xFF))));
1145 }
1146
1147 #[test]
1148 fn test_invalid_message_type_zero() {
1149 let mut buf = BytesMut::new();
1150 buf.put_u8(0x00); buf.put_i32(4);
1152 let result = FrontendMessage::decode(&mut buf);
1153 assert!(matches!(result, Err(ProtocolError::InvalidMessageType(0x00))));
1154 }
1155
1156 #[test]
1157 fn test_invalid_message_type_lowercase_q() {
1158 let mut buf = BytesMut::new();
1160 buf.put_u8(b'q');
1161 buf.put_i32(13);
1162 buf.put_slice(b"SELECT 1\0");
1163 let result = FrontendMessage::decode(&mut buf);
1164 assert!(matches!(result, Err(ProtocolError::InvalidMessageType(b'q'))));
1165 }
1166
1167 #[test]
1168 fn test_invalid_message_type_numeric() {
1169 let mut buf = BytesMut::new();
1170 buf.put_u8(b'1'); buf.put_i32(4);
1172 let result = FrontendMessage::decode(&mut buf);
1173 assert!(matches!(result, Err(ProtocolError::InvalidMessageType(b'1'))));
1174 }
1175
1176 #[test]
1181 fn test_length_zero() {
1182 let mut buf = BytesMut::new();
1183 buf.put_u8(b'X'); buf.put_i32(0); let result = FrontendMessage::decode(&mut buf);
1186 assert!(matches!(result, Err(ProtocolError::InvalidMessageLength(0))));
1188 }
1189
1190 #[test]
1191 fn test_length_negative() {
1192 let mut buf = BytesMut::new();
1193 buf.put_u8(b'X');
1194 buf.put_i32(-1); let result = FrontendMessage::decode(&mut buf);
1196 assert!(matches!(result, Err(ProtocolError::InvalidMessageLength(-1))));
1198 }
1199
1200 #[test]
1201 fn test_length_too_small() {
1202 let mut buf = BytesMut::new();
1203 buf.put_u8(b'X');
1204 buf.put_i32(3); let result = FrontendMessage::decode(&mut buf);
1206 assert!(matches!(result, Err(ProtocolError::InvalidMessageLength(3))));
1207 }
1208
1209 #[test]
1210 fn test_startup_length_too_small() {
1211 let mut buf = BytesMut::new();
1212 buf.put_i32(4); let result = FrontendMessage::decode_startup(&mut buf);
1214 assert!(matches!(result, Err(ProtocolError::InvalidMessageLength(4))));
1216 }
1217
1218 #[test]
1219 fn test_startup_length_negative() {
1220 let mut buf = BytesMut::new();
1221 buf.put_i32(-1); let result = FrontendMessage::decode_startup(&mut buf);
1223 assert!(matches!(result, Err(ProtocolError::InvalidMessageLength(-1))));
1224 }
1225
1226 #[test]
1231 fn test_invalid_utf8_in_query() {
1232 let mut buf = BytesMut::new();
1233 buf.put_u8(b'Q');
1234 buf.put_i32(8); buf.put_slice(&[0xFF, 0xFE, 0x80]); buf.put_u8(0); let result = FrontendMessage::decode(&mut buf);
1238 assert!(matches!(result, Err(ProtocolError::InvalidString)));
1239 }
1240
1241 #[test]
1242 fn test_invalid_utf8_continuation_byte() {
1243 let mut buf = BytesMut::new();
1244 buf.put_u8(b'Q');
1245 buf.put_i32(6); buf.put_u8(0x80); buf.put_u8(0); let result = FrontendMessage::decode(&mut buf);
1249 assert!(matches!(result, Err(ProtocolError::InvalidString)));
1250 }
1251
1252 #[test]
1253 fn test_invalid_utf8_overlong_encoding() {
1254 let mut buf = BytesMut::new();
1255 buf.put_u8(b'Q');
1256 buf.put_i32(7);
1257 buf.put_slice(&[0xC0, 0x80]); buf.put_u8(0); let result = FrontendMessage::decode(&mut buf);
1260 assert!(matches!(result, Err(ProtocolError::InvalidString)));
1261 }
1262
1263 #[test]
1264 fn test_invalid_utf8_in_password() {
1265 let mut buf = BytesMut::new();
1266 buf.put_u8(b'p'); buf.put_i32(8);
1268 buf.put_slice(&[0xFE, 0xFF, 0x00]); buf.put_u8(0);
1270 let result = FrontendMessage::decode(&mut buf);
1271 assert!(result.is_ok() || matches!(result, Err(ProtocolError::InvalidString)));
1273 }
1274
1275 #[test]
1276 fn test_invalid_utf8_in_startup_user() {
1277 let mut buf = BytesMut::new();
1278 buf.put_i32(17);
1281 buf.put_i32(196608); buf.put_slice(b"user\0");
1283 buf.put_slice(&[0xFF, 0xFE]); buf.put_u8(0); buf.put_u8(0); let result = FrontendMessage::decode_startup(&mut buf);
1287 assert!(matches!(result, Err(ProtocolError::InvalidString)));
1289 }
1290
1291 #[test]
1296 fn test_query_missing_null_terminator() {
1297 let mut buf = BytesMut::new();
1298 buf.put_u8(b'Q');
1299 buf.put_i32(12); buf.put_slice(b"SELECT 1"); let result = FrontendMessage::decode(&mut buf);
1302 assert!(matches!(result, Err(ProtocolError::InvalidString)));
1303 }
1304
1305 #[test]
1306 fn test_startup_missing_final_null() {
1307 let mut buf = BytesMut::new();
1308 buf.put_i32(18);
1311 buf.put_i32(196608); buf.put_slice(b"user\0test\0"); let result = FrontendMessage::decode_startup(&mut buf);
1314 assert!(result.is_ok());
1317 let msg = result.unwrap();
1318 assert!(matches!(msg, Some(FrontendMessage::Startup { .. })));
1319 }
1320
1321 #[test]
1326 fn test_terminate_minimal() {
1327 let mut buf = BytesMut::new();
1329 buf.put_u8(b'X');
1330 buf.put_i32(4); let result = FrontendMessage::decode(&mut buf);
1332 assert!(result.is_ok());
1333 assert!(matches!(result.unwrap(), Some(FrontendMessage::Terminate)));
1334 }
1335
1336 #[test]
1337 fn test_query_empty_string() {
1338 let mut buf = BytesMut::new();
1339 buf.put_u8(b'Q');
1340 buf.put_i32(5); buf.put_u8(0); let result = FrontendMessage::decode(&mut buf);
1343 assert!(result.is_ok());
1344 assert!(matches!(
1345 result.unwrap(),
1346 Some(FrontendMessage::Query { query }) if query.is_empty()
1347 ));
1348 }
1349
1350 #[test]
1355 fn test_ssl_request_detection() {
1356 let mut buf = BytesMut::new();
1357 buf.put_i32(8); buf.put_i32(80877103); let result = FrontendMessage::decode_startup(&mut buf);
1360 assert!(result.is_ok());
1361 assert!(matches!(result.unwrap(), Some(FrontendMessage::SSLRequest)));
1362 }
1363
1364 #[test]
1369 fn test_startup_protocol_version_3_0() {
1370 let mut buf = BytesMut::new();
1371 buf.put_i32(17); buf.put_i32(196608); buf.put_slice(b"user\0pg\0"); buf.put_u8(0); let result = FrontendMessage::decode_startup(&mut buf);
1376 assert!(result.is_ok());
1377 let msg = result.unwrap();
1378 assert!(matches!(
1379 msg,
1380 Some(FrontendMessage::Startup { protocol_version, params })
1381 if protocol_version == 196608 && params.get("user") == Some(&"pg".to_string())
1382 ));
1383 }
1384
1385 #[test]
1390 fn test_buffer_properly_consumed_after_query() {
1391 let mut buf = BytesMut::new();
1392 buf.put_u8(b'Q');
1394 buf.put_i32(10);
1395 buf.put_slice(b"test1\0");
1396 buf.put_u8(b'Q');
1398 buf.put_i32(10);
1399 buf.put_slice(b"test2\0");
1400
1401 let result1 = FrontendMessage::decode(&mut buf);
1402 assert!(matches!(
1403 result1.unwrap(),
1404 Some(FrontendMessage::Query { query }) if query == "test1"
1405 ));
1406
1407 let result2 = FrontendMessage::decode(&mut buf);
1408 assert!(matches!(
1409 result2.unwrap(),
1410 Some(FrontendMessage::Query { query }) if query == "test2"
1411 ));
1412 }
1413
1414 #[test]
1415 fn test_buffer_not_consumed_on_incomplete() {
1416 let mut buf = BytesMut::new();
1417 buf.put_u8(b'Q');
1418 buf.put_i32(100); let original_len = buf.len();
1421 let result = FrontendMessage::decode(&mut buf);
1422 assert!(result.is_ok());
1423 assert!(result.unwrap().is_none());
1424 assert_eq!(buf.len(), original_len); }
1426
1427 #[test]
1432 fn test_very_large_declared_length() {
1433 let mut buf = BytesMut::new();
1434 buf.put_u8(b'Q');
1435 buf.put_i32(i32::MAX); buf.put_slice(b"small\0");
1437 let result = FrontendMessage::decode(&mut buf);
1438 assert!(result.is_ok());
1440 assert!(result.unwrap().is_none());
1441 }
1442
1443 #[test]
1448 fn test_password_message_valid() {
1449 let mut buf = BytesMut::new();
1450 buf.put_u8(b'p');
1451 buf.put_i32(13); buf.put_slice(b"secret\0");
1453 buf.put_slice(&[0, 0]);
1455 let result = FrontendMessage::decode(&mut buf);
1456 assert!(result.is_ok());
1457 assert!(matches!(
1458 result.unwrap(),
1459 Some(FrontendMessage::Password { password }) if password == "secret"
1460 ));
1461 }
1462
1463 #[test]
1464 fn test_password_message_empty() {
1465 let mut buf = BytesMut::new();
1466 buf.put_u8(b'p');
1467 buf.put_i32(5); buf.put_u8(0);
1469 let result = FrontendMessage::decode(&mut buf);
1470 assert!(result.is_ok());
1471 assert!(matches!(
1472 result.unwrap(),
1473 Some(FrontendMessage::Password { password }) if password.is_empty()
1474 ));
1475 }
1476
1477 #[test]
1482 fn test_subscribe_with_selective_updates_config_full() {
1483 let mut buf = BytesMut::new();
1485 buf.put_u8(0xF0); let mut body = BytesMut::new();
1489
1490 body.put_slice(b"SELECT * FROM test\0");
1492
1493 body.put_i16(0);
1495
1496 body.put_i16(0);
1498
1499 body.put_u8(0x07); body.put_u8(1); body.put_u16(5); body.put_f64(0.75); buf.put_i32((4 + body.len()) as i32);
1507 buf.put_slice(&body);
1508
1509 let result = FrontendMessage::decode(&mut buf);
1510 assert!(result.is_ok());
1511
1512 let msg = result.unwrap();
1513 assert!(matches!(msg, Some(FrontendMessage::Subscribe { .. })));
1514
1515 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = msg {
1516 assert!(selective_updates_config.is_some());
1517 let config = selective_updates_config.unwrap();
1518 assert_eq!(config.enabled, Some(true));
1519 assert_eq!(config.min_changed_columns, Some(5));
1520 assert_eq!(config.max_changed_columns_ratio, Some(0.75));
1521 } else {
1522 panic!("Expected Subscribe message");
1523 }
1524 }
1525
1526 #[test]
1527 fn test_subscribe_with_partial_selective_config_enabled_only() {
1528 let mut buf = BytesMut::new();
1530 buf.put_u8(0xF0); let mut body = BytesMut::new();
1533 body.put_slice(b"SELECT * FROM test\0");
1534 body.put_i16(0); body.put_i16(0); body.put_u8(0x01); body.put_u8(1); buf.put_i32((4 + body.len()) as i32);
1541 buf.put_slice(&body);
1542
1543 let result = FrontendMessage::decode(&mut buf);
1544 assert!(result.is_ok());
1545
1546 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1547 assert!(selective_updates_config.is_some());
1548 let config = selective_updates_config.unwrap();
1549 assert_eq!(config.enabled, Some(true));
1550 assert_eq!(config.min_changed_columns, None);
1551 assert_eq!(config.max_changed_columns_ratio, None);
1552 } else {
1553 panic!("Expected Subscribe message with config");
1554 }
1555 }
1556
1557 #[test]
1558 fn test_subscribe_with_partial_selective_config_min_columns_only() {
1559 let mut buf = BytesMut::new();
1561 buf.put_u8(0xF0); let mut body = BytesMut::new();
1564 body.put_slice(b"SELECT * FROM test\0");
1565 body.put_i16(0); body.put_i16(0); body.put_u8(0x02); body.put_u16(10); buf.put_i32((4 + body.len()) as i32);
1572 buf.put_slice(&body);
1573
1574 let result = FrontendMessage::decode(&mut buf);
1575 assert!(result.is_ok());
1576
1577 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1578 assert!(selective_updates_config.is_some());
1579 let config = selective_updates_config.unwrap();
1580 assert_eq!(config.enabled, None);
1581 assert_eq!(config.min_changed_columns, Some(10));
1582 assert_eq!(config.max_changed_columns_ratio, None);
1583 } else {
1584 panic!("Expected Subscribe message with config");
1585 }
1586 }
1587
1588 #[test]
1589 fn test_subscribe_with_partial_selective_config_max_ratio_only() {
1590 let mut buf = BytesMut::new();
1592 buf.put_u8(0xF0); let mut body = BytesMut::new();
1595 body.put_slice(b"SELECT * FROM test\0");
1596 body.put_i16(0); body.put_i16(0); body.put_u8(0x04); body.put_f64(0.5); buf.put_i32((4 + body.len()) as i32);
1603 buf.put_slice(&body);
1604
1605 let result = FrontendMessage::decode(&mut buf);
1606 assert!(result.is_ok());
1607
1608 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1609 assert!(selective_updates_config.is_some());
1610 let config = selective_updates_config.unwrap();
1611 assert_eq!(config.enabled, None);
1612 assert_eq!(config.min_changed_columns, None);
1613 assert_eq!(config.max_changed_columns_ratio, Some(0.5));
1614 } else {
1615 panic!("Expected Subscribe message with config");
1616 }
1617 }
1618
1619 #[test]
1620 fn test_subscribe_with_selective_config_zero_flags() {
1621 let mut buf = BytesMut::new();
1623 buf.put_u8(0xF0); let mut body = BytesMut::new();
1626 body.put_slice(b"SELECT * FROM test\0");
1627 body.put_i16(0); body.put_i16(0); body.put_u8(0x00); buf.put_i32((4 + body.len()) as i32);
1632 buf.put_slice(&body);
1633
1634 let result = FrontendMessage::decode(&mut buf);
1635 assert!(result.is_ok());
1636
1637 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1638 assert!(selective_updates_config.is_none());
1639 } else {
1640 panic!("Expected Subscribe message");
1641 }
1642 }
1643
1644 #[test]
1645 fn test_subscribe_without_selective_config_field() {
1646 let mut buf = BytesMut::new();
1648 buf.put_u8(0xF0); let mut body = BytesMut::new();
1651 body.put_slice(b"SELECT * FROM test\0");
1652 body.put_i16(0); body.put_i16(0); buf.put_i32((4 + body.len()) as i32);
1657 buf.put_slice(&body);
1658
1659 let result = FrontendMessage::decode(&mut buf);
1660 assert!(result.is_ok());
1661
1662 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1663 assert!(selective_updates_config.is_none());
1664 } else {
1665 panic!("Expected Subscribe message");
1666 }
1667 }
1668
1669 #[test]
1670 fn test_subscribe_with_config_disabled() {
1671 let mut buf = BytesMut::new();
1673 buf.put_u8(0xF0); let mut body = BytesMut::new();
1676 body.put_slice(b"SELECT * FROM test\0");
1677 body.put_i16(0); body.put_i16(0); body.put_u8(0x01); body.put_u8(0); buf.put_i32((4 + body.len()) as i32);
1684 buf.put_slice(&body);
1685
1686 let result = FrontendMessage::decode(&mut buf);
1687 assert!(result.is_ok());
1688
1689 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1690 assert!(selective_updates_config.is_some());
1691 let config = selective_updates_config.unwrap();
1692 assert_eq!(config.enabled, Some(false));
1693 } else {
1694 panic!("Expected Subscribe message with config");
1695 }
1696 }
1697
1698 #[test]
1699 fn test_subscribe_with_combined_flags() {
1700 let mut buf = BytesMut::new();
1702 buf.put_u8(0xF0); let mut body = BytesMut::new();
1705 body.put_slice(b"SELECT * FROM test\0");
1706 body.put_i16(0); body.put_i16(0); body.put_u8(0x03); body.put_u8(1); body.put_u16(3); buf.put_i32((4 + body.len()) as i32);
1714 buf.put_slice(&body);
1715
1716 let result = FrontendMessage::decode(&mut buf);
1717 assert!(result.is_ok());
1718
1719 if let Some(FrontendMessage::Subscribe { selective_updates_config, .. }) = result.unwrap() {
1720 assert!(selective_updates_config.is_some());
1721 let config = selective_updates_config.unwrap();
1722 assert_eq!(config.enabled, Some(true));
1723 assert_eq!(config.min_changed_columns, Some(3));
1724 assert_eq!(config.max_changed_columns_ratio, None);
1725 } else {
1726 panic!("Expected Subscribe message with config");
1727 }
1728 }
1729 }
1730}