1use std::collections::HashMap;
2use std::fmt;
3use std::str::Utf8Error;
4
5use aws_sdk_s3::operation::delete_object::DeleteObjectError;
6use aws_sdk_s3::operation::get_object::{GetObjectError, GetObjectOutput};
7use aws_sdk_s3::operation::put_object::{PutObjectError, PutObjectOutput};
8use aws_sdk_s3::primitives::ByteStreamError;
9use aws_sdk_sqs::operation::change_message_visibility::builders::ChangeMessageVisibilityFluentBuilder;
10use aws_sdk_sqs::operation::change_message_visibility::{
11 ChangeMessageVisibilityError, ChangeMessageVisibilityOutput,
12};
13use aws_sdk_sqs::operation::delete_message::builders::DeleteMessageFluentBuilder;
14use aws_sdk_sqs::operation::delete_message::{DeleteMessageError, DeleteMessageOutput};
15use aws_sdk_sqs::operation::receive_message::builders::ReceiveMessageFluentBuilder;
16use aws_sdk_sqs::operation::receive_message::{ReceiveMessageError, ReceiveMessageOutput};
17use aws_sdk_sqs::operation::send_message::builders::SendMessageFluentBuilder;
18use aws_sdk_sqs::operation::send_message::{SendMessageError, SendMessageOutput};
19use aws_sdk_sqs::types::Message;
20use aws_sdk_sqs::types::MessageAttributeValue;
21use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
22use aws_smithy_runtime_api::client::result::SdkError;
23use aws_smithy_runtime_api::http::Response;
24use aws_smithy_types::byte_stream::ByteStream;
25use aws_smithy_types::error::operation::BuildError;
26use regex::Regex;
27use serde::{Deserialize, Serialize};
28use serde_json::Result as SerdeJsonResult;
29use uuid::Uuid;
30
/// Default size threshold in bytes (262144 = 256 * 1024). The builder uses it as the initial
/// value of both the per-message and the batch threshold.
const MAX_MESSAGE_SIZE_IN_BYTES: usize = 262144;
/// Default class name written as the first element of a serialized S3 pointer.
static DEFAULT_POINTER_CLASS: &str = "software.amazon.payloadoffloading.PayloadS3Pointer";
/// Second entry of the builder's default reserved attribute names.
// NOTE(review): presumably kept so messages written under the older attribute name are still
// recognized on receive — confirm.
static LEGACY_RESERVED_ATTRIBUTE_NAME: &str = "SQSLargePayloadSize";
34
/// Collects the configuration for an [`SqsExtendedClient`]; finish with `build`.
pub struct SqsExtendedClientBuilder {
    /// S3 client that is moved into the built client.
    s3_client: aws_sdk_s3::Client,
    /// Bucket for offloaded payloads; `None` until `with_s3_bucket_name` is called.
    bucket_name: Option<String>,
    /// Combined body + attribute size in bytes above which a message is offloaded.
    message_size_threshold: usize,
    /// Batch size threshold in bytes.
    // NOTE(review): `build` does not copy this value into `SqsExtendedClient`, so it currently
    // has no effect.
    batch_message_size_threshold: usize,
    /// When `true`, every message body is offloaded regardless of size.
    always_s3: bool,
    /// Class name written as the first element of the serialized S3 pointer.
    pointer_class: String,
    /// Attribute names that mark a message as offloaded; the first one is used when sending.
    reserved_attributes: Vec<String>,
    /// Optional key prefix; when non-empty, object keys become `<prefix>/<uuid>`.
    object_prefix: String,
}
47
48impl SqsExtendedClientBuilder {
49 pub fn new(s3_client: aws_sdk_s3::Client) -> SqsExtendedClientBuilder {
50 SqsExtendedClientBuilder {
51 s3_client,
52 bucket_name: None,
53 message_size_threshold: MAX_MESSAGE_SIZE_IN_BYTES,
54 batch_message_size_threshold: MAX_MESSAGE_SIZE_IN_BYTES,
55 always_s3: false,
56 pointer_class: DEFAULT_POINTER_CLASS.to_string(),
57 reserved_attributes: vec![
58 "ExtendedPayloadSize".to_string(),
59 LEGACY_RESERVED_ATTRIBUTE_NAME.to_string(),
60 ],
61 object_prefix: "".to_string(),
62 }
63 }
64
65 pub fn with_s3_bucket_name(mut self, bucket_name: String) -> SqsExtendedClientBuilder {
66 self.bucket_name = Some(bucket_name);
67 self
68 }
69
70 pub fn with_message_size_threshold(mut self, msg_size: usize) -> SqsExtendedClientBuilder {
71 self.message_size_threshold = msg_size;
72 self
73 }
74
75 pub fn with_batch_message_size_threshold(
76 mut self,
77 batch_msg_size: usize,
78 ) -> SqsExtendedClientBuilder {
79 self.batch_message_size_threshold = batch_msg_size;
80 self
81 }
82
83 pub fn with_always_through_s3(mut self, always_s3: bool) -> SqsExtendedClientBuilder {
84 self.always_s3 = always_s3;
85 self
86 }
87
88 pub fn with_reserved_attribute_names(
89 mut self,
90 reserved_attribute_names: Vec<String>,
91 ) -> SqsExtendedClientBuilder {
92 self.reserved_attributes = reserved_attribute_names;
93 self
94 }
95
96 pub fn with_pointer_class(mut self, pointer_class: String) -> SqsExtendedClientBuilder {
97 self.pointer_class = pointer_class;
98 self
99 }
100
101 pub fn with_object_prefix(mut self, prefix: String) -> SqsExtendedClientBuilder {
102 self.object_prefix = prefix;
103 self
104 }
105
106 pub fn build(self) -> SqsExtendedClient {
107 let receipt_handler_regex: Regex = Regex::new(r"^-\.\.s3BucketName\.\.-(.*)-\.\.s3BucketName\.\.--\.\.s3Key\.\.-(.*)-\.\.s3Key\.\.-(.*)").unwrap();
108
109 SqsExtendedClient {
110 s3_client: self.s3_client,
111 bucket_name: self.bucket_name,
112 message_size_threshold: self.message_size_threshold,
113 always_through_s3: self.always_s3,
114 pointer_class: self.pointer_class,
115 reserved_attributes: self.reserved_attributes,
116 object_prefix: self.object_prefix,
117 extended_receipt_handler_regex: receipt_handler_regex,
118 }
119 }
120}
121
/// SQS helper that stores message bodies in S3 and sends a pointer through SQS instead.
///
/// Created with [`SqsExtendedClientBuilder`].
pub struct SqsExtendedClient {
    /// S3 client used to upload, download and delete offloaded bodies.
    s3_client: aws_sdk_s3::Client,
    /// Bucket for offloaded bodies; `send_message` fails when this is `None`.
    bucket_name: Option<String>,
    /// Combined body + attribute size in bytes above which a message is offloaded.
    message_size_threshold: usize,
    /// When `true`, every message body is offloaded regardless of size.
    always_through_s3: bool,
    /// Class name written as the first element of the serialized S3 pointer.
    pointer_class: String,
    /// Attribute names that mark a message as offloaded; the first one is used when sending.
    reserved_attributes: Vec<String>,
    /// Optional key prefix; when non-empty, object keys become `<prefix>/<uuid>`.
    object_prefix: String,
    /// Matches extended receipt handles and captures bucket, key and the original handle.
    extended_receipt_handler_regex: Regex,
}
134
135impl SqsExtendedClient {
136 pub async fn send_message(
137 &self,
138 msg_input: SendMessageFluentBuilder,
139 ) -> Result<SendMessageOutput, SqsExtendedClientError> {
140 let Some(bn) = &self.bucket_name else {
141 return Err(SqsExtendedClientError::NoBucketName);
142 };
143 let bucket_name: String = bn.to_string();
144
145 let Some(msg_bdy) = msg_input.get_message_body() else {
146 return Err(SqsExtendedClientError::NoMessageBody);
147 };
148 let message_body: &str = msg_bdy;
149
150 let result: Result<SendMessageOutput, SdkError<SendMessageError, Response>> = if self
151 .always_through_s3
152 || self.message_exceeds_threshold(message_body, msg_input.get_message_attributes())
153 {
154 let s3_key: String = self.s3_key(Uuid::new_v4().to_string());
155
156 let s3_result: Result<PutObjectOutput, SdkError<PutObjectError, HttpResponse>> = self
157 .s3_client
158 .put_object()
159 .bucket(&bucket_name)
160 .key(&s3_key)
161 .body(ByteStream::from(message_body.as_bytes().to_vec()))
162 .send()
163 .await;
164
165 if let Err(s3_error) = s3_result {
166 return Err(SqsExtendedClientError::S3Upload(s3_error));
167 }
168
169 let new_msg: S3Pointer = S3Pointer {
170 s3_bucket_name: bucket_name,
171 s3_key,
172 class: self.pointer_class.clone(),
173 };
174
175 let message_body_size: usize = message_body.len();
176
177 let reserved_attribute: MessageAttributeValue = MessageAttributeValue::builder()
178 .data_type("Number")
179 .string_value(message_body_size.to_string())
180 .build()?;
181
182 msg_input
183 .message_body(new_msg.marshall_json())
184 .message_attributes(self.reserved_attributes[0].clone(), reserved_attribute)
185 .send()
186 .await
187 } else {
188 msg_input.send().await
189 };
190
191 result.map_err(SqsExtendedClientError::SqsSendMessage)
192 }
193
194 pub async fn receive_message(
195 &self,
196 receive_message_builder: ReceiveMessageFluentBuilder,
197 ) -> Result<ReceiveMessageOutput, SqsExtendedClientError> {
198 let mut sqs_response: ReceiveMessageOutput = receive_message_builder
199 .message_attribute_names("All")
200 .send()
201 .await?;
202
203 let mut messages: Vec<Message> = match &sqs_response.messages {
204 None => return Ok(sqs_response),
205 Some(msgs) => msgs.clone(),
206 };
207
208 for msg in messages.iter_mut() {
209 let mut found: bool = false;
210
211 for rsrvd_attr in self.reserved_attributes.iter() {
212 let msg_attrs: HashMap<String, MessageAttributeValue> =
213 match &msg.message_attributes {
214 None => break,
215 Some(ma) => ma.clone(),
216 };
217
218 if msg_attrs.contains_key(rsrvd_attr.as_str()) {
219 found = true;
220 break;
221 }
222 }
223
224 if !found {
225 continue;
226 }
227
228 let body: String = match &msg.body {
229 None => return Err(SqsExtendedClientError::NoMessageBody),
230 Some(b) => b.to_string(),
231 };
232
233 let receipt_handle: String = match &msg.receipt_handle {
234 None => return Err(SqsExtendedClientError::NoReceiptHandle),
235 Some(rh) => rh.to_string(),
236 };
237
238 let s3_pointer = S3Pointer::unmarshall_json(&body)?;
239
240 let object: GetObjectOutput = self
241 .s3_client
242 .get_object()
243 .bucket(s3_pointer.s3_bucket_name.clone())
244 .key(s3_pointer.s3_key.clone())
245 .send()
246 .await?;
247
248 let bytes = object.body.collect().await?.into_bytes();
249 let response: &str = std::str::from_utf8(&bytes)?;
250
251 msg.body = Some(response.to_string());
252 msg.receipt_handle = Some(Self::new_extended_receipt_handle(
253 s3_pointer.s3_bucket_name.clone(),
254 s3_pointer.s3_key.clone(),
255 receipt_handle,
256 ))
257 }
258
259 sqs_response.messages = Some(messages);
260 Ok(sqs_response)
261 }
262
263 pub async fn delete_message(
264 &self,
265 mut delete_message_builder: DeleteMessageFluentBuilder,
266 ) -> Result<DeleteMessageOutput, SqsExtendedClientError> {
267 let receipt_handle: String = match delete_message_builder.get_receipt_handle() {
268 None => return Err(SqsExtendedClientError::NoReceiptHandle),
269 Some(rh) => rh.to_string(),
270 };
271
272 let (bucket, key, handle) = self.parse_extended_receipt_handle(receipt_handle.clone());
273
274 if !bucket.is_empty() && !key.is_empty() && !handle.is_empty() {
275 delete_message_builder =
276 delete_message_builder.set_receipt_handle(Some(handle.clone()));
277 }
278
279 let resp: DeleteMessageOutput = delete_message_builder.send().await?;
280
281 if !bucket.is_empty() && !key.is_empty() {
282 self.s3_client
283 .delete_object()
284 .bucket(bucket)
285 .key(key)
286 .send()
287 .await?;
288 }
289
290 Ok(resp)
291 }
292
293 pub async fn change_message_visibility(
294 &self,
295 mut change_message_visibility: ChangeMessageVisibilityFluentBuilder,
296 ) -> Result<ChangeMessageVisibilityOutput, SqsExtendedClientError> {
297 let receipt_handle: String = match change_message_visibility.get_receipt_handle() {
298 None => return Err(SqsExtendedClientError::NoReceiptHandle),
299 Some(rh) => rh.to_string(),
300 };
301
302 let (bucket, key, handle) = self.parse_extended_receipt_handle(receipt_handle.clone());
303
304 if !bucket.is_empty() && !key.is_empty() && !handle.is_empty() {
305 change_message_visibility =
306 change_message_visibility.set_receipt_handle(Some(handle.clone()));
307 }
308
309 let resp: ChangeMessageVisibilityOutput = change_message_visibility.send().await?;
310
311 Ok(resp)
312 }
313
314 fn message_exceeds_threshold(
315 &self,
316 body: &str,
317 attributes: &Option<HashMap<String, MessageAttributeValue>>,
318 ) -> bool {
319 self.message_size(body, attributes).total() > self.message_size_threshold
320 }
321
322 fn message_size(
323 &self,
324 body: &str,
325 attributes: &Option<HashMap<String, MessageAttributeValue>>,
326 ) -> MessageSize {
327 MessageSize {
328 body_size: body.len(),
329 attribute_size: self.attribute_size(attributes),
330 }
331 }
332
333 fn attribute_size(&self, attributes: &Option<HashMap<String, MessageAttributeValue>>) -> usize {
334 match attributes {
335 None => 0,
336 Some(hash_map) => self.calc_attribute_size(hash_map),
337 }
338 }
339
340 fn calc_attribute_size(&self, attributes: &HashMap<String, MessageAttributeValue>) -> usize {
341 let mut sum: usize = 0;
342 for (k, v) in attributes {
343 sum += k.len();
344
345 match &v.binary_value {
346 None => {}
347 Some(blob) => {
348 sum += blob.as_ref().len();
349 }
350 }
351 match &v.string_value {
352 None => {}
353 Some(string) => sum += string.len(),
354 }
355 sum += v.data_type.len();
356 }
357 sum
358 }
359
360 fn s3_key(&self, filename: String) -> String {
361 if !self.object_prefix.is_empty() {
362 return format!("{}/{}", self.object_prefix, filename);
363 }
364 filename
365 }
366
367 fn new_extended_receipt_handle(bucket: String, key: String, handle: String) -> String {
368 let s3_bucket_name_marker: String = "-..s3BucketName..-".to_string();
369 let s3_key_marker: String = "-..s3Key..-".to_string();
370
371 format!(
372 "{}{}{}{}{}{}{}",
373 s3_bucket_name_marker,
374 bucket,
375 s3_bucket_name_marker,
376 s3_key_marker,
377 key,
378 s3_key_marker,
379 handle
380 )
381 .to_string()
382 }
383
384 fn parse_extended_receipt_handle(
385 &self,
386 extended_receipt_handle: String,
387 ) -> (String, String, String) {
388 let caps = match self
389 .extended_receipt_handler_regex
390 .captures(&extended_receipt_handle)
391 {
392 Some(caps) => caps,
393 None => return ("".to_string(), "".to_string(), "".to_string()),
394 };
395
396 if caps.len() != 4 {
397 return ("".to_string(), "".to_string(), "".to_string());
398 }
399
400 let bucket = caps
401 .get(1)
402 .map(|m| m.as_str().to_string())
403 .unwrap_or_default();
404 let key = caps
405 .get(2)
406 .map(|m| m.as_str().to_string())
407 .unwrap_or_default();
408 let receipt_handle = caps
409 .get(3)
410 .map(|m| m.as_str().to_string())
411 .unwrap_or_default();
412
413 (bucket, key, receipt_handle)
414 }
415}
416
/// JSON shape of an S3 pointer message body: a two-element array holding the pointer class
/// name followed by the bucket/key object.
#[derive(Serialize, Deserialize, Debug)]
struct S3PointerArray(String, S3PointerBucketAndKeyObject);
421
/// Second element of the pointer array: where the offloaded body is stored.
#[derive(Serialize, Deserialize, Debug)]
struct S3PointerBucketAndKeyObject {
    /// Bucket holding the object; JSON key `s3BucketName`.
    #[serde(rename = "s3BucketName")]
    s3_bucket_name: String,
    /// Key of the object; JSON key `s3Key`.
    #[serde(rename = "s3Key")]
    s3_key: String,
}
429
/// In-memory form of an S3 pointer; converted to and from its JSON message-body form with
/// `marshall_json` / `unmarshall_json`.
#[derive(Serialize, Deserialize, Debug)]
struct S3Pointer {
    /// Bucket holding the offloaded body.
    s3_bucket_name: String,
    /// Key of the offloaded body.
    s3_key: String,
    /// Pointer class name (first element of the JSON array).
    class: String,
}
436
437impl S3Pointer {
438 fn marshall_json(self) -> String {
439 format!(
440 "[\"{}\",{{\"s3BucketName\":\"{}\",\"s3Key\":\"{}\"}}]",
441 self.class, self.s3_bucket_name, self.s3_key
442 )
443 }
444
445 fn unmarshall_json(input: &str) -> SerdeJsonResult<S3Pointer> {
446 let wrapper: S3PointerArray = serde_json::from_str(input)?;
447
448 let s3_pointer: S3Pointer = S3Pointer {
449 s3_bucket_name: wrapper.1.s3_bucket_name,
450 s3_key: wrapper.1.s3_key,
451 class: wrapper.0,
452 };
453
454 Ok(s3_pointer)
455 }
456}
457
/// Size of a message in bytes, split into its body and its attributes.
struct MessageSize {
    /// Byte length of the message body.
    body_size: usize,
    /// Combined byte size of the message attributes.
    attribute_size: usize,
}

impl MessageSize {
    /// Consumes the measurement and returns body size plus attribute size.
    fn total(self) -> usize {
        let Self {
            body_size,
            attribute_size,
        } = self;

        body_size + attribute_size
    }
}
470
/// Errors returned by [`SqsExtendedClient`] operations.
#[derive(Debug)]
pub enum SqsExtendedClientError {
    /// Uploading a message body to S3 failed (`send_message`).
    S3Upload(SdkError<PutObjectError, HttpResponse>),
    /// Fetching an offloaded body from S3 failed (`receive_message`).
    S3Download(SdkError<GetObjectError, Response>),
    /// Deleting an offloaded body from S3 failed (`delete_message`).
    S3DeleteObject(SdkError<DeleteObjectError, Response>),
    /// Reading the downloaded object's byte stream failed.
    S3DownloadToBytes(ByteStreamError),
    /// The downloaded object was not valid UTF-8.
    S3DownloadToUtf8(Utf8Error),
    /// The SQS send call failed.
    SqsSendMessage(SdkError<SendMessageError, HttpResponse>),
    /// The SQS receive call failed.
    SqsReceiveMessage(SdkError<ReceiveMessageError, HttpResponse>),
    /// The SQS delete call failed.
    SqsDeleteMessage(SdkError<DeleteMessageError, Response>),
    /// The SQS change-visibility call failed.
    SqsChangeMessageVisibility(SdkError<ChangeMessageVisibilityError, Response>),
    /// Building the size message attribute for an offloaded message failed.
    SqsBuildMessageAttribute(BuildError),
    /// The body of a received offloaded message could not be parsed as an S3 pointer.
    SqsReceiveMessageUnMarshallMessageBody(serde_json::Error),
    /// No S3 bucket name was configured on the client.
    NoBucketName,
    /// A message had no body where one was required.
    NoMessageBody,
    /// A message or request had no receipt handle where one was required.
    NoReceiptHandle,
}
490
491impl fmt::Display for SqsExtendedClientError {
492 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
493 match self {
494 Self::S3Upload(err) => write!(f, "S3 upload failed: {}", err),
495 Self::S3Download(err) => write!(f, "S3 download failed: {}", err),
496 Self::S3DeleteObject(err) => write!(f, "S3 delete failed: {}", err),
497 Self::S3DownloadToBytes(err) => write!(f, "S3 Byte Stream Error: {}", err),
498 Self::S3DownloadToUtf8(err) => write!(f, "S3 Byte Stream Error: {}", err),
499 Self::SqsSendMessage(err) => write!(f, "SQS operation failed: {}", err),
500 Self::SqsReceiveMessage(err) => write!(f, "SQS operation failed: {}", err),
501 Self::SqsDeleteMessage(err) => write!(f, "SQS delete failed: {}", err),
502 Self::SqsChangeMessageVisibility(err) => {
503 write!(f, "SQS change message visibilty failed: {}", err)
504 }
505 Self::SqsBuildMessageAttribute(err) => {
506 write!(f, "SQS build message attribute failed: {}", err)
507 }
508 Self::SqsReceiveMessageUnMarshallMessageBody(err) => {
509 write!(f, "Failed to marshall sqs message body: {}", err)
510 }
511 Self::NoBucketName => write!(f, "No bucket name configured"),
512 Self::NoMessageBody => write!(f, "No message body"),
513 Self::NoReceiptHandle => write!(f, "No receipt handle"),
514 }
515 }
516}
517
518impl From<aws_sdk_s3::error::BuildError> for SqsExtendedClientError {
519 fn from(err: aws_sdk_s3::error::BuildError) -> Self {
520 Self::SqsBuildMessageAttribute(err)
521 }
522}
523
524impl From<Utf8Error> for SqsExtendedClientError {
525 fn from(err: Utf8Error) -> Self {
526 Self::S3DownloadToUtf8(err)
527 }
528}
529
530impl From<ByteStreamError> for SqsExtendedClientError {
531 fn from(err: ByteStreamError) -> Self {
532 Self::S3DownloadToBytes(err)
533 }
534}
535
536impl From<SdkError<GetObjectError, Response>> for SqsExtendedClientError {
537 fn from(err: SdkError<GetObjectError, Response>) -> Self {
538 Self::S3Download(err)
539 }
540}
541
542impl From<SdkError<DeleteObjectError, Response>> for SqsExtendedClientError {
543 fn from(err: SdkError<DeleteObjectError, Response>) -> Self {
544 Self::S3DeleteObject(err)
545 }
546}
547
548impl From<SdkError<ReceiveMessageError, HttpResponse>> for SqsExtendedClientError {
549 fn from(err: SdkError<ReceiveMessageError, HttpResponse>) -> Self {
550 Self::SqsReceiveMessage(err)
551 }
552}
553
554impl From<SdkError<DeleteMessageError, Response>> for SqsExtendedClientError {
555 fn from(err: SdkError<DeleteMessageError, Response>) -> Self {
556 Self::SqsDeleteMessage(err)
557 }
558}
559
560impl From<SdkError<ChangeMessageVisibilityError, Response>> for SqsExtendedClientError {
561 fn from(err: SdkError<ChangeMessageVisibilityError, Response>) -> Self {
562 Self::SqsChangeMessageVisibility(err)
563 }
564}
565
566impl From<serde_json::Error> for SqsExtendedClientError {
567 fn from(err: serde_json::Error) -> Self {
568 Self::SqsReceiveMessageUnMarshallMessageBody(err)
569 }
570}
571
572impl std::error::Error for SqsExtendedClientError {}
573
#[cfg(test)]
mod tests {
    use aws_config::BehaviorVersion;
    use aws_smithy_types::Blob;

    use super::*;

    /// Placeholder credentials; the tests below only build clients and requests and never
    /// send them.
    fn make_test_credentials() -> aws_sdk_s3::config::Credentials {
        aws_sdk_s3::config::Credentials::new(
            "TEST_ACCESS_KEY_ID",
            "TEST_SECRET_ACCESS_KEY",
            Some("TEST_SESSION_TOKEN".to_string()),
            None,
            "",
        )
    }

    /// S3 client built from the placeholder credentials.
    fn make_test_s3_client() -> aws_sdk_s3::client::Client {
        aws_sdk_s3::Client::from_conf(
            aws_sdk_s3::Config::builder()
                .behavior_version(BehaviorVersion::latest())
                .credentials_provider(make_test_credentials())
                .build(),
        )
    }

    /// SQS client built from the placeholder credentials; used to create request builders.
    fn make_test_sqs_client() -> aws_sdk_sqs::client::Client {
        aws_sdk_sqs::Client::from_conf(
            aws_sdk_sqs::Config::builder()
                .behavior_version(BehaviorVersion::latest())
                .credentials_provider(make_test_credentials())
                .build(),
        )
    }

    /// Every builder setter that reaches the client is reflected in the built client.
    #[test]
    fn test_builder_fns() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client())
                .with_s3_bucket_name("bucket-name".to_string())
                .with_message_size_threshold(9999)
                .with_batch_message_size_threshold(1000)
                .with_always_through_s3(true)
                .with_reserved_attribute_names(vec!["attr_one".to_string(), "attr_two".to_string()])
                .with_pointer_class("pointer-class".to_string())
                .with_object_prefix("object-prefix".to_string())
                .build();

        let bucket_name: String = sqs_extended_client.bucket_name.unwrap_or_default();
        assert_eq!("bucket-name", bucket_name);
        assert_eq!(9999, sqs_extended_client.message_size_threshold);
        assert!(sqs_extended_client.always_through_s3);
        assert_eq!(
            vec!["attr_one".to_string(), "attr_two".to_string()],
            sqs_extended_client.reserved_attributes
        );
        assert_eq!("pointer-class", sqs_extended_client.pointer_class);
        assert_eq!("object-prefix", sqs_extended_client.object_prefix);
    }

    /// A builder with no setters called produces the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let bucket_name: String = sqs_extended_client.bucket_name.unwrap_or_default();

        assert_eq!("", bucket_name);
        assert_eq!(
            MAX_MESSAGE_SIZE_IN_BYTES,
            sqs_extended_client.message_size_threshold
        );
        assert!(!sqs_extended_client.always_through_s3);
        assert_eq!(
            vec![
                "ExtendedPayloadSize".to_string(),
                LEGACY_RESERVED_ATTRIBUTE_NAME.to_string(),
            ],
            sqs_extended_client.reserved_attributes
        );
        assert_eq!(DEFAULT_POINTER_CLASS, sqs_extended_client.pointer_class);
        assert_eq!("", sqs_extended_client.object_prefix);
    }

    /// An 11-byte body exceeds a threshold of 8 but not the default threshold.
    #[test]
    fn test_message_exceeds_threshold() {
        let sqs_extended_client_small_msg_size_threshold: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client())
                .with_message_size_threshold(8)
                .build();

        let msg: SendMessageFluentBuilder = make_test_sqs_client()
            .send_message()
            .queue_url("queue_url")
            .message_body("hello world");

        let msg_body: &String = msg.get_message_body().as_ref().unwrap();

        let result: bool = sqs_extended_client_small_msg_size_threshold
            .message_exceeds_threshold(msg_body, msg.get_message_attributes());

        assert!(result);

        let sqs_extended_client_default: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let result_2: bool = sqs_extended_client_default
            .message_exceeds_threshold(msg_body, msg.get_message_attributes());

        assert!(!result_2);
    }

    /// Message size counts the body alone, and body plus attributes when attributes are set.
    #[test]
    fn test_message_size() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let msg: SendMessageFluentBuilder = make_test_sqs_client()
            .send_message()
            .queue_url("queue_url")
            .message_body("hello world");

        let msg_body: &String = msg.get_message_body().as_ref().unwrap();

        let result: MessageSize =
            sqs_extended_client.message_size(msg_body, msg.get_message_attributes());

        assert_eq!(result.total(), 11);

        let attribute_value_number: MessageAttributeValue = MessageAttributeValue::builder()
            .data_type("Number")
            .string_value("12")
            .build()
            .expect("Failed to build MessageAttributeValue");

        let attribute_value_binary: MessageAttributeValue = MessageAttributeValue::builder()
            .data_type("Binary")
            .binary_value(Blob::new("IT'S BINARY HONEST"))
            .build()
            .expect("Failed to build MessageAttributeValue");

        let msg_with_attributes: SendMessageFluentBuilder = make_test_sqs_client()
            .send_message()
            .queue_url("queue_url")
            .message_body("hello world")
            .message_attributes("test_attribute_one", attribute_value_number)
            .message_attributes("test_attribute_two", attribute_value_binary);

        let msg_with_attributes_body: &String =
            msg_with_attributes.get_message_body().as_ref().unwrap();

        let result_2: MessageSize = sqs_extended_client.message_size(
            msg_with_attributes_body,
            msg_with_attributes.get_message_attributes(),
        );

        // 11 (body) + 18 + 6 + 2 (first attribute) + 18 + 6 + 18 (second attribute).
        assert_eq!(result_2.total(), 79);
    }

    /// Attribute size is name + data type + value length: 15 + 6 + 11.
    #[test]
    fn test_calc_attribute_size() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let reserved_attribute = MessageAttributeValue::builder()
            .data_type(String::from("String"))
            .string_value(String::from("some string"))
            .build()
            .expect("build MessageAttributeValue should not fail");

        let mut hm: HashMap<String, MessageAttributeValue> = HashMap::new();
        hm.insert(String::from("testing_strings"), reserved_attribute);

        assert_eq!(32, sqs_extended_client.calc_attribute_size(&hm))
    }

    /// Keys are returned unchanged without a prefix and as `<prefix>/<key>` with one.
    #[test]
    fn test_s3_key() {
        let sqs_extended_client_no_prefix: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        assert_eq!(
            String::from("00000000-0000-0000-0000-000000000000"),
            sqs_extended_client_no_prefix
                .s3_key(String::from("00000000-0000-0000-0000-000000000000"))
        );

        let sqs_extended_client_with_prefix: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client())
                .with_object_prefix(String::from("some_prefix"))
                .build();

        assert_eq!(
            String::from("some_prefix/00000000-0000-0000-0000-000000000000"),
            sqs_extended_client_with_prefix
                .s3_key(String::from("00000000-0000-0000-0000-000000000000"))
        )
    }

    /// Bucket, key and handle are wrapped in the expected markers.
    #[test]
    fn test_new_extended_receipt_handle() {
        let bucket: String = "BUCKET".to_string();
        let key: String = "KEY".to_string();
        let handle: String = "HANDLE".to_string();

        let result: String = SqsExtendedClient::new_extended_receipt_handle(bucket, key, handle);

        assert_eq!(
            result,
            "-..s3BucketName..-BUCKET-..s3BucketName..--..s3Key..-KEY-..s3Key..-HANDLE"
        );
    }

    /// An extended handle is split back into bucket, key and the original handle.
    #[test]
    fn test_parse_extended_receipt_handle() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let extended_receipt_handler: String =
            "-..s3BucketName..-BUCKET-..s3BucketName..--..s3Key..-KEY-..s3Key..-HANDLE".to_string();

        let (bucket, key, receipt_handle) =
            sqs_extended_client.parse_extended_receipt_handle(extended_receipt_handler);

        assert_eq!(bucket, "BUCKET");
        assert_eq!(key, "KEY");
        assert_eq!(receipt_handle, "HANDLE");
    }

    /// A pointer serializes to the compact two-element JSON array form.
    #[test]
    fn test_marshall_json() {
        let sqs_extended_client: SqsExtendedClient =
            SqsExtendedClientBuilder::new(make_test_s3_client()).build();

        let s3_pointer: S3Pointer = S3Pointer {
            s3_bucket_name: "BUCKET".to_string(),
            s3_key: "KEY".to_string(),
            class: sqs_extended_client.pointer_class.clone(),
        };

        let json_s3_pointer: String = s3_pointer.marshall_json();

        assert_eq!(
            json_s3_pointer,
            r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{"s3BucketName":"BUCKET","s3Key":"KEY"}]"#
        );
    }

    /// The JSON array form parses back into bucket, key and class.
    #[test]
    fn test_unmarshall_json() {
        let s3_pointer_str: &str = r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{"s3BucketName":"BUCKET","s3Key":"KEY"}]"#;

        let s3_pointer_struct: S3Pointer =
            S3Pointer::unmarshall_json(s3_pointer_str).expect("s3_pointer unmarshall failed");

        assert_eq!(s3_pointer_struct.s3_bucket_name, "BUCKET");
        assert_eq!(s3_pointer_struct.s3_key, "KEY");
        assert_eq!(
            s3_pointer_struct.class,
            SqsExtendedClientBuilder::new(make_test_s3_client())
                .build()
                .pointer_class
        );
    }
}