1#![allow(clippy::cast_possible_truncation)]
6
7use super::messages::{
8 CANCEL_REQUEST_CODE, DescribeKind, FrontendMessage, SSL_REQUEST_CODE, frontend_type,
9};
10
/// Serializes [`FrontendMessage`]s into PostgreSQL wire-format bytes.
///
/// A single output buffer is kept inside the writer and cleared (not freed) before each
/// message is encoded, so repeated writes reuse the same allocation.
#[derive(Clone, Debug)]
pub struct MessageWriter {
    /// Output buffer holding the bytes of the most recently encoded message.
    buf: Vec<u8>,
}
19
20impl Default for MessageWriter {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl MessageWriter {
27 pub fn new() -> Self {
29 Self::with_capacity(1024)
30 }
31
32 pub fn with_capacity(capacity: usize) -> Self {
34 Self {
35 buf: Vec::with_capacity(capacity),
36 }
37 }
38
39 pub fn clear(&mut self) {
41 self.buf.clear();
42 }
43
44 pub fn as_bytes(&self) -> &[u8] {
46 &self.buf
47 }
48
49 pub fn take(&mut self) -> Vec<u8> {
51 std::mem::take(&mut self.buf)
52 }
53
54 pub fn write(&mut self, msg: &FrontendMessage) -> &[u8] {
58 self.buf.clear();
59
60 match msg {
61 FrontendMessage::Startup { version, params } => {
62 self.write_startup(*version, params);
63 }
64 FrontendMessage::PasswordMessage(password) => {
65 self.write_password(password);
66 }
67 FrontendMessage::SASLInitialResponse { mechanism, data } => {
68 self.write_sasl_initial(mechanism, data);
69 }
70 FrontendMessage::SASLResponse(data) => {
71 self.write_sasl_response(data);
72 }
73 FrontendMessage::Query(query) => {
74 self.write_query(query);
75 }
76 FrontendMessage::Parse {
77 name,
78 query,
79 param_types,
80 } => {
81 self.write_parse(name, query, param_types);
82 }
83 FrontendMessage::Bind {
84 portal,
85 statement,
86 param_formats,
87 params,
88 result_formats,
89 } => {
90 self.write_bind(portal, statement, param_formats, params, result_formats);
91 }
92 FrontendMessage::Describe { kind, name } => {
93 self.write_describe(*kind, name);
94 }
95 FrontendMessage::Execute { portal, max_rows } => {
96 self.write_execute(portal, *max_rows);
97 }
98 FrontendMessage::Close { kind, name } => {
99 self.write_close(*kind, name);
100 }
101 FrontendMessage::Sync => {
102 self.write_sync();
103 }
104 FrontendMessage::Flush => {
105 self.write_flush();
106 }
107 FrontendMessage::CopyData(data) => {
108 self.write_copy_data(data);
109 }
110 FrontendMessage::CopyDone => {
111 self.write_copy_done();
112 }
113 FrontendMessage::CopyFail(message) => {
114 self.write_copy_fail(message);
115 }
116 FrontendMessage::Terminate => {
117 self.write_terminate();
118 }
119 FrontendMessage::CancelRequest {
120 process_id,
121 secret_key,
122 } => {
123 self.write_cancel_request(*process_id, *secret_key);
124 }
125 FrontendMessage::SSLRequest => {
126 self.write_ssl_request();
127 }
128 }
129
130 &self.buf
131 }
132
133 fn write_startup(&mut self, version: i32, params: &[(String, String)]) {
137 let mut body_len = 4; for (key, value) in params {
140 body_len += key.len() + 1 + value.len() + 1;
141 }
142 body_len += 1; let total_len = (body_len + 4) as i32;
146 self.buf.extend_from_slice(&total_len.to_be_bytes());
147
148 self.buf.extend_from_slice(&version.to_be_bytes());
150
151 for (key, value) in params {
153 self.buf.extend_from_slice(key.as_bytes());
154 self.buf.push(0);
155 self.buf.extend_from_slice(value.as_bytes());
156 self.buf.push(0);
157 }
158
159 self.buf.push(0);
161 }
162
163 fn write_password(&mut self, password: &str) {
165 self.write_simple_string_message(frontend_type::PASSWORD, password);
166 }
167
168 fn write_sasl_initial(&mut self, mechanism: &str, data: &[u8]) {
170 self.buf.push(frontend_type::PASSWORD);
172
173 let body_len = mechanism.len() + 1 + 4 + data.len();
175 let total_len = (body_len + 4) as i32;
176 self.buf.extend_from_slice(&total_len.to_be_bytes());
177
178 self.buf.extend_from_slice(mechanism.as_bytes());
180 self.buf.push(0);
181
182 if data.is_empty() {
184 self.buf.extend_from_slice(&(-1_i32).to_be_bytes());
185 } else {
186 let data_len = data.len() as i32;
187 self.buf.extend_from_slice(&data_len.to_be_bytes());
188 self.buf.extend_from_slice(data);
189 }
190 }
191
192 fn write_sasl_response(&mut self, data: &[u8]) {
194 self.buf.push(frontend_type::PASSWORD);
195 let len = (data.len() + 4) as i32;
196 self.buf.extend_from_slice(&len.to_be_bytes());
197 self.buf.extend_from_slice(data);
198 }
199
200 fn write_query(&mut self, query: &str) {
202 self.write_simple_string_message(frontend_type::QUERY, query);
203 }
204
205 fn write_parse(&mut self, name: &str, query: &str, param_types: &[u32]) {
207 self.buf.push(frontend_type::PARSE);
208
209 let body_len = name.len() + 1 + query.len() + 1 + 2 + (param_types.len() * 4);
211 let total_len = (body_len + 4) as i32;
212 self.buf.extend_from_slice(&total_len.to_be_bytes());
213
214 self.buf.extend_from_slice(name.as_bytes());
216 self.buf.push(0);
217
218 self.buf.extend_from_slice(query.as_bytes());
220 self.buf.push(0);
221
222 let num_params = param_types.len() as i16;
224 self.buf.extend_from_slice(&num_params.to_be_bytes());
225 for &oid in param_types {
226 self.buf.extend_from_slice(&oid.to_be_bytes());
227 }
228 }
229
230 fn write_bind(
232 &mut self,
233 portal: &str,
234 statement: &str,
235 param_formats: &[i16],
236 params: &[Option<Vec<u8>>],
237 result_formats: &[i16],
238 ) {
239 self.buf.push(frontend_type::BIND);
240
241 let mut body_len = portal.len() + 1 + statement.len() + 1;
243 body_len += 2 + (param_formats.len() * 2); body_len += 2; for param in params {
247 body_len += 4; if let Some(data) = param {
249 body_len += data.len();
250 }
251 }
252
253 body_len += 2 + (result_formats.len() * 2); let total_len = (body_len + 4) as i32;
256 self.buf.extend_from_slice(&total_len.to_be_bytes());
257
258 self.buf.extend_from_slice(portal.as_bytes());
260 self.buf.push(0);
261
262 self.buf.extend_from_slice(statement.as_bytes());
264 self.buf.push(0);
265
266 let num_formats = param_formats.len() as i16;
268 self.buf.extend_from_slice(&num_formats.to_be_bytes());
269 for &fmt in param_formats {
270 self.buf.extend_from_slice(&fmt.to_be_bytes());
271 }
272
273 let num_params = params.len() as i16;
275 self.buf.extend_from_slice(&num_params.to_be_bytes());
276 for param in params {
277 match param {
278 Some(data) => {
279 let len = data.len() as i32;
280 self.buf.extend_from_slice(&len.to_be_bytes());
281 self.buf.extend_from_slice(data);
282 }
283 None => {
284 self.buf.extend_from_slice(&(-1_i32).to_be_bytes());
286 }
287 }
288 }
289
290 let num_result_formats = result_formats.len() as i16;
292 self.buf
293 .extend_from_slice(&num_result_formats.to_be_bytes());
294 for &fmt in result_formats {
295 self.buf.extend_from_slice(&fmt.to_be_bytes());
296 }
297 }
298
299 fn write_describe(&mut self, kind: DescribeKind, name: &str) {
301 self.buf.push(frontend_type::DESCRIBE);
302 let body_len = 1 + name.len() + 1;
303 let total_len = (body_len + 4) as i32;
304 self.buf.extend_from_slice(&total_len.to_be_bytes());
305 self.buf.push(kind.as_byte());
306 self.buf.extend_from_slice(name.as_bytes());
307 self.buf.push(0);
308 }
309
310 fn write_execute(&mut self, portal: &str, max_rows: i32) {
312 self.buf.push(frontend_type::EXECUTE);
313 let body_len = portal.len() + 1 + 4;
314 let total_len = (body_len + 4) as i32;
315 self.buf.extend_from_slice(&total_len.to_be_bytes());
316 self.buf.extend_from_slice(portal.as_bytes());
317 self.buf.push(0);
318 self.buf.extend_from_slice(&max_rows.to_be_bytes());
319 }
320
321 fn write_close(&mut self, kind: DescribeKind, name: &str) {
323 self.buf.push(frontend_type::CLOSE);
324 let body_len = 1 + name.len() + 1;
325 let total_len = (body_len + 4) as i32;
326 self.buf.extend_from_slice(&total_len.to_be_bytes());
327 self.buf.push(kind.as_byte());
328 self.buf.extend_from_slice(name.as_bytes());
329 self.buf.push(0);
330 }
331
332 fn write_sync(&mut self) {
334 self.write_empty_message(frontend_type::SYNC);
335 }
336
337 fn write_flush(&mut self) {
339 self.write_empty_message(frontend_type::FLUSH);
340 }
341
342 fn write_copy_data(&mut self, data: &[u8]) {
344 self.buf.push(frontend_type::COPY_DATA);
345 let len = (data.len() + 4) as i32;
346 self.buf.extend_from_slice(&len.to_be_bytes());
347 self.buf.extend_from_slice(data);
348 }
349
350 fn write_copy_done(&mut self) {
352 self.write_empty_message(frontend_type::COPY_DONE);
353 }
354
355 fn write_copy_fail(&mut self, message: &str) {
357 self.write_simple_string_message(frontend_type::COPY_FAIL, message);
358 }
359
360 fn write_terminate(&mut self) {
362 self.write_empty_message(frontend_type::TERMINATE);
363 }
364
365 fn write_cancel_request(&mut self, process_id: i32, secret_key: i32) {
367 self.buf.extend_from_slice(&16_i32.to_be_bytes());
369 self.buf
371 .extend_from_slice(&CANCEL_REQUEST_CODE.to_be_bytes());
372 self.buf.extend_from_slice(&process_id.to_be_bytes());
374 self.buf.extend_from_slice(&secret_key.to_be_bytes());
376 }
377
378 fn write_ssl_request(&mut self) {
380 self.buf.extend_from_slice(&8_i32.to_be_bytes());
382 self.buf.extend_from_slice(&SSL_REQUEST_CODE.to_be_bytes());
384 }
385
386 fn write_empty_message(&mut self, type_byte: u8) {
390 self.buf.push(type_byte);
391 self.buf.extend_from_slice(&4_i32.to_be_bytes());
392 }
393
394 fn write_simple_string_message(&mut self, type_byte: u8, s: &str) {
396 self.buf.push(type_byte);
397 let len = (s.len() + 5) as i32; self.buf.extend_from_slice(&len.to_be_bytes());
399 self.buf.extend_from_slice(s.as_bytes());
400 self.buf.push(0);
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::PROTOCOL_VERSION;

    /// Decodes the big-endian Int32 stored at `offset` in `data`.
    fn read_i32(data: &[u8], offset: usize) -> i32 {
        i32::from_be_bytes([
            data[offset],
            data[offset + 1],
            data[offset + 2],
            data[offset + 3],
        ])
    }

    #[test]
    fn test_startup_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Startup {
            version: PROTOCOL_VERSION,
            params: vec![
                ("user".to_string(), "postgres".to_string()),
                ("database".to_string(), "test".to_string()),
            ],
        };

        let data = writer.write(&msg);

        // Startup has no type byte, so the length prefix covers the entire message.
        assert_eq!(read_i32(data, 0) as usize, data.len());
        assert_eq!(read_i32(data, 4), PROTOCOL_VERSION);
        // Name/value pairs, then an empty name terminating the list.
        assert_eq!(&data[8..], b"user\0postgres\0database\0test\0\0");
    }

    #[test]
    fn test_query_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Query("SELECT 1".to_string());

        let data = writer.write(&msg);

        // 'Q', length 13 (4 for the length itself + 8 bytes of text + NUL), text, NUL.
        assert_eq!(data, b"Q\0\0\0\x0dSELECT 1\0");
    }

    #[test]
    fn test_sync_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Sync;

        let data = writer.write(&msg);

        assert_eq!(data, &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_flush_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Flush;

        let data = writer.write(&msg);

        assert_eq!(data, &[b'H', 0, 0, 0, 4]);
    }

    #[test]
    fn test_terminate_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Terminate;

        let data = writer.write(&msg);

        assert_eq!(data, &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_copy_done_message() {
        let mut writer = MessageWriter::new();

        let data = writer.write(&FrontendMessage::CopyDone);

        assert_eq!(data, &[frontend_type::COPY_DONE, 0, 0, 0, 4]);
    }

    #[test]
    fn test_parse_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Parse {
            name: "stmt1".to_string(),
            query: "SELECT $1".to_string(),
            param_types: vec![23], // int4
        };

        let data = writer.write(&msg);

        let mut expected = vec![b'P', 0, 0, 0, 26];
        expected.extend_from_slice(b"stmt1\0SELECT $1\0");
        // One parameter type: OID 23.
        expected.extend_from_slice(&[0, 1, 0, 0, 0, 23]);
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_parse_param_count_above_i16_max() {
        // Counts between i16::MAX and u16::MAX must keep their unsigned 16-bit encoding
        // (libpq allows up to 65535 parameters): 40_000 == 0x9C40.
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Parse {
            name: String::new(),
            query: String::new(),
            param_types: vec![0; 40_000],
        };

        let data = writer.write(&msg);

        // 'P', Int32 length, empty name NUL, empty query NUL, then the Int16 count.
        assert_eq!(&data[7..9], &[0x9C_u8, 0x40]);
        assert_eq!(read_i32(data, 1) as usize, data.len() - 1);
    }

    #[test]
    fn test_describe_statement() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Statement,
            name: "stmt1".to_string(),
        };

        let data = writer.write(&msg);

        // 'D', length 11, kind 'S', name, NUL.
        assert_eq!(data, b"D\0\0\0\x0bSstmt1\0");
    }

    #[test]
    fn test_describe_portal() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Describe {
            kind: DescribeKind::Portal,
            name: "portal1".to_string(),
        };

        let data = writer.write(&msg);

        // 'D', length 13, kind 'P', name, NUL.
        assert_eq!(data, b"D\0\0\0\x0dPportal1\0");
    }

    #[test]
    fn test_execute_message() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Execute {
            portal: String::new(),
            max_rows: 100,
        };

        let data = writer.write(&msg);

        // 'E', length 9, empty portal NUL, Int32 max_rows.
        assert_eq!(data, &[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 100]);
    }

    #[test]
    fn test_cancel_request() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::CancelRequest {
            process_id: 12345,
            secret_key: 67890,
        };

        let data = writer.write(&msg);

        assert_eq!(data.len(), 16);
        assert_eq!(read_i32(data, 0), 16);
        assert_eq!(read_i32(data, 4), CANCEL_REQUEST_CODE);
        assert_eq!(read_i32(data, 8), 12345);
        assert_eq!(read_i32(data, 12), 67890);
    }

    #[test]
    fn test_ssl_request() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::SSLRequest;

        let data = writer.write(&msg);

        assert_eq!(data.len(), 8);
        assert_eq!(read_i32(data, 0), 8);
        assert_eq!(read_i32(data, 4), SSL_REQUEST_CODE);
    }

    #[test]
    fn test_bind_with_null_params() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: String::new(),
            statement: "stmt1".to_string(),
            param_formats: vec![0],
            params: vec![None], // NULL
            result_formats: vec![],
        };

        let data = writer.write(&msg);

        // 'B', length 23, unnamed portal NUL, statement name.
        let mut expected = vec![b'B', 0, 0, 0, 23, 0];
        expected.extend_from_slice(b"stmt1\0");
        // One parameter format: text (0).
        expected.extend_from_slice(&[0, 1, 0, 0]);
        // One parameter value: NULL is encoded as length -1 with no data.
        expected.extend_from_slice(&[0, 1, 0xFF, 0xFF, 0xFF, 0xFF]);
        // No result formats.
        expected.extend_from_slice(&[0, 0]);
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_bind_distinguishes_empty_from_null() {
        let mut writer = MessageWriter::new();
        let msg = FrontendMessage::Bind {
            portal: "p".to_string(),
            statement: String::new(),
            param_formats: vec![1, 0],
            params: vec![Some(vec![0xDE, 0xAD]), Some(vec![])],
            result_formats: vec![1],
        };

        let data = writer.write(&msg);

        let mut expected = vec![b'B', 0, 0, 0, 29];
        // Portal "p", unnamed statement.
        expected.extend_from_slice(b"p\0\0");
        // Two parameter formats: binary (1), text (0).
        expected.extend_from_slice(&[0, 2, 0, 1, 0, 0]);
        // Two parameter values.
        expected.extend_from_slice(&[0, 2]);
        expected.extend_from_slice(&[0, 0, 0, 2, 0xDE, 0xAD]);
        // An empty value has length 0, not -1.
        expected.extend_from_slice(&[0, 0, 0, 0]);
        // One result format: binary (1).
        expected.extend_from_slice(&[0, 1, 0, 1]);
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_copy_data() {
        let mut writer = MessageWriter::new();
        let payload = b"hello\nworld\n";
        let msg = FrontendMessage::CopyData(payload.to_vec());

        let data = writer.write(&msg);

        assert_eq!(data[0], b'd');
        assert_eq!(read_i32(data, 1), (4 + payload.len()) as i32);
        assert_eq!(&data[5..], payload);
    }

    #[test]
    fn test_writer_reuse() {
        let mut writer = MessageWriter::new();

        // Each write replaces the previous message rather than appending to it.
        writer.write(&FrontendMessage::Sync);
        assert_eq!(writer.as_bytes(), &[b'S', 0, 0, 0, 4]);

        writer.write(&FrontendMessage::Flush);
        assert_eq!(writer.as_bytes(), &[b'H', 0, 0, 0, 4]);

        // `take` hands over the encoded bytes and leaves the writer empty.
        let taken = writer.take();
        assert_eq!(taken, vec![b'H', 0, 0, 0, 4]);
        assert!(writer.as_bytes().is_empty());
    }
}