sqs_extended_client/
lib.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::str::Utf8Error;
4
5use aws_sdk_s3::operation::delete_object::DeleteObjectError;
6use aws_sdk_s3::operation::get_object::{GetObjectError, GetObjectOutput};
7use aws_sdk_s3::operation::put_object::{PutObjectError, PutObjectOutput};
8use aws_sdk_s3::primitives::ByteStreamError;
9use aws_sdk_sqs::operation::change_message_visibility::builders::ChangeMessageVisibilityFluentBuilder;
10use aws_sdk_sqs::operation::change_message_visibility::{
11    ChangeMessageVisibilityError, ChangeMessageVisibilityOutput,
12};
13use aws_sdk_sqs::operation::delete_message::builders::DeleteMessageFluentBuilder;
14use aws_sdk_sqs::operation::delete_message::{DeleteMessageError, DeleteMessageOutput};
15use aws_sdk_sqs::operation::receive_message::builders::ReceiveMessageFluentBuilder;
16use aws_sdk_sqs::operation::receive_message::{ReceiveMessageError, ReceiveMessageOutput};
17use aws_sdk_sqs::operation::send_message::builders::SendMessageFluentBuilder;
18use aws_sdk_sqs::operation::send_message::{SendMessageError, SendMessageOutput};
19use aws_sdk_sqs::types::Message;
20use aws_sdk_sqs::types::MessageAttributeValue;
21use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
22use aws_smithy_runtime_api::client::result::SdkError;
23use aws_smithy_runtime_api::http::Response;
24use aws_smithy_types::byte_stream::ByteStream;
25use aws_smithy_types::error::operation::BuildError;
26use regex::Regex;
27use serde::{Deserialize, Serialize};
28use serde_json::Result as SerdeJsonResult;
29use uuid::Uuid;
30
31const MAX_MESSAGE_SIZE_IN_BYTES: usize = 262144;
32static DEFAULT_POINTER_CLASS: &str = "software.amazon.payloadoffloading.PayloadS3Pointer";
33static LEGACY_RESERVED_ATTRIBUTE_NAME: &str = "SQSLargePayloadSize";
34
35//-SQS EXTENDED CLIENT BUILDER--------------------------------------------------
36
37pub struct SqsExtendedClientBuilder {
38    s3_client: aws_sdk_s3::Client,
39    bucket_name: Option<String>,
40    message_size_threshold: usize,
41    batch_message_size_threshold: usize,
42    always_s3: bool,
43    pointer_class: String,
44    reserved_attributes: Vec<String>,
45    object_prefix: String,
46}
47
48impl SqsExtendedClientBuilder {
49    pub fn new(s3_client: aws_sdk_s3::Client) -> SqsExtendedClientBuilder {
50        SqsExtendedClientBuilder {
51            s3_client,
52            bucket_name: None,
53            message_size_threshold: MAX_MESSAGE_SIZE_IN_BYTES,
54            batch_message_size_threshold: MAX_MESSAGE_SIZE_IN_BYTES,
55            always_s3: false,
56            pointer_class: DEFAULT_POINTER_CLASS.to_string(),
57            reserved_attributes: vec![
58                "ExtendedPayloadSize".to_string(),
59                LEGACY_RESERVED_ATTRIBUTE_NAME.to_string(),
60            ],
61            object_prefix: "".to_string(),
62        }
63    }
64
65    pub fn with_s3_bucket_name(mut self, bucket_name: String) -> SqsExtendedClientBuilder {
66        self.bucket_name = Some(bucket_name);
67        self
68    }
69
70    pub fn with_message_size_threshold(mut self, msg_size: usize) -> SqsExtendedClientBuilder {
71        self.message_size_threshold = msg_size;
72        self
73    }
74
75    pub fn with_batch_message_size_threshold(
76        mut self,
77        batch_msg_size: usize,
78    ) -> SqsExtendedClientBuilder {
79        self.batch_message_size_threshold = batch_msg_size;
80        self
81    }
82
83    pub fn with_always_through_s3(mut self, always_s3: bool) -> SqsExtendedClientBuilder {
84        self.always_s3 = always_s3;
85        self
86    }
87
88    pub fn with_reserved_attribute_names(
89        mut self,
90        reserved_attribute_names: Vec<String>,
91    ) -> SqsExtendedClientBuilder {
92        self.reserved_attributes = reserved_attribute_names;
93        self
94    }
95
96    pub fn with_pointer_class(mut self, pointer_class: String) -> SqsExtendedClientBuilder {
97        self.pointer_class = pointer_class;
98        self
99    }
100
101    pub fn with_object_prefix(mut self, prefix: String) -> SqsExtendedClientBuilder {
102        self.object_prefix = prefix;
103        self
104    }
105
106    pub fn build(self) -> SqsExtendedClient {
107        let receipt_handler_regex: Regex = Regex::new(r"^-\.\.s3BucketName\.\.-(.*)-\.\.s3BucketName\.\.--\.\.s3Key\.\.-(.*)-\.\.s3Key\.\.-(.*)").unwrap();
108
109        SqsExtendedClient {
110            s3_client: self.s3_client,
111            bucket_name: self.bucket_name,
112            message_size_threshold: self.message_size_threshold,
113            always_through_s3: self.always_s3,
114            pointer_class: self.pointer_class,
115            reserved_attributes: self.reserved_attributes,
116            object_prefix: self.object_prefix,
117            extended_receipt_handler_regex: receipt_handler_regex,
118        }
119    }
120}
121
122//-SQS EXTENDED CLIENT----------------------------------------------------------
123
124pub struct SqsExtendedClient {
125    s3_client: aws_sdk_s3::Client,
126    bucket_name: Option<String>,
127    message_size_threshold: usize,
128    always_through_s3: bool,
129    pointer_class: String,
130    reserved_attributes: Vec<String>,
131    object_prefix: String,
132    extended_receipt_handler_regex: Regex,
133}
134
135impl SqsExtendedClient {
136    pub async fn send_message(
137        &self,
138        msg_input: SendMessageFluentBuilder,
139    ) -> Result<SendMessageOutput, SqsExtendedClientError> {
140        let Some(bn) = &self.bucket_name else {
141            return Err(SqsExtendedClientError::NoBucketName);
142        };
143        let bucket_name: String = bn.to_string();
144
145        let Some(msg_bdy) = msg_input.get_message_body() else {
146            return Err(SqsExtendedClientError::NoMessageBody);
147        };
148        let message_body: &str = msg_bdy;
149
150        let result: Result<SendMessageOutput, SdkError<SendMessageError, Response>> = if self
151            .always_through_s3
152            || self.message_exceeds_threshold(message_body, msg_input.get_message_attributes())
153        {
154            let s3_key: String = self.s3_key(Uuid::new_v4().to_string());
155
156            let s3_result: Result<PutObjectOutput, SdkError<PutObjectError, HttpResponse>> = self
157                .s3_client
158                .put_object()
159                .bucket(&bucket_name)
160                .key(&s3_key)
161                .body(ByteStream::from(message_body.as_bytes().to_vec()))
162                .send()
163                .await;
164
165            if let Err(s3_error) = s3_result {
166                return Err(SqsExtendedClientError::S3Upload(s3_error));
167            }
168
169            let new_msg: S3Pointer = S3Pointer {
170                s3_bucket_name: bucket_name,
171                s3_key,
172                class: self.pointer_class.clone(),
173            };
174
175            let message_body_size: usize = message_body.len();
176
177            let reserved_attribute: MessageAttributeValue = MessageAttributeValue::builder()
178                .data_type("Number")
179                .string_value(message_body_size.to_string())
180                .build()?;
181
182            msg_input
183                .message_body(new_msg.marshall_json())
184                .message_attributes(self.reserved_attributes[0].clone(), reserved_attribute)
185                .send()
186                .await
187        } else {
188            msg_input.send().await
189        };
190
191        result.map_err(SqsExtendedClientError::SqsSendMessage)
192    }
193
194    pub async fn receive_message(
195        &self,
196        receive_message_builder: ReceiveMessageFluentBuilder,
197    ) -> Result<ReceiveMessageOutput, SqsExtendedClientError> {
198        let mut sqs_response: ReceiveMessageOutput = receive_message_builder
199            .message_attribute_names("All")
200            .send()
201            .await?;
202
203        let mut messages: Vec<Message> = match &sqs_response.messages {
204            None => return Ok(sqs_response),
205            Some(msgs) => msgs.clone(),
206        };
207
208        for msg in messages.iter_mut() {
209            let mut found: bool = false;
210
211            for rsrvd_attr in self.reserved_attributes.iter() {
212                let msg_attrs: HashMap<String, MessageAttributeValue> =
213                    match &msg.message_attributes {
214                        None => break,
215                        Some(ma) => ma.clone(),
216                    };
217
218                if msg_attrs.contains_key(rsrvd_attr.as_str()) {
219                    found = true;
220                    break;
221                }
222            }
223
224            if !found {
225                continue;
226            }
227
228            let body: String = match &msg.body {
229                None => return Err(SqsExtendedClientError::NoMessageBody),
230                Some(b) => b.to_string(),
231            };
232
233            let receipt_handle: String = match &msg.receipt_handle {
234                None => return Err(SqsExtendedClientError::NoReceiptHandle),
235                Some(rh) => rh.to_string(),
236            };
237
238            let s3_pointer = S3Pointer::unmarshall_json(&body)?;
239
240            let object: GetObjectOutput = self
241                .s3_client
242                .get_object()
243                .bucket(s3_pointer.s3_bucket_name.clone())
244                .key(s3_pointer.s3_key.clone())
245                .send()
246                .await?;
247
248            let bytes = object.body.collect().await?.into_bytes();
249            let response: &str = std::str::from_utf8(&bytes)?;
250
251            msg.body = Some(response.to_string());
252            msg.receipt_handle = Some(Self::new_extended_receipt_handle(
253                s3_pointer.s3_bucket_name.clone(),
254                s3_pointer.s3_key.clone(),
255                receipt_handle,
256            ))
257        }
258
259        sqs_response.messages = Some(messages);
260        Ok(sqs_response)
261    }
262
263    pub async fn delete_message(
264        &self,
265        mut delete_message_builder: DeleteMessageFluentBuilder,
266    ) -> Result<DeleteMessageOutput, SqsExtendedClientError> {
267        let receipt_handle: String = match delete_message_builder.get_receipt_handle() {
268            None => return Err(SqsExtendedClientError::NoReceiptHandle),
269            Some(rh) => rh.to_string(),
270        };
271
272        let (bucket, key, handle) = self.parse_extended_receipt_handle(receipt_handle.clone());
273
274        if !bucket.is_empty() && !key.is_empty() && !handle.is_empty() {
275            delete_message_builder =
276                delete_message_builder.set_receipt_handle(Some(handle.clone()));
277        }
278
279        let resp: DeleteMessageOutput = delete_message_builder.send().await?;
280
281        if !bucket.is_empty() && !key.is_empty() {
282            self.s3_client
283                .delete_object()
284                .bucket(bucket)
285                .key(key)
286                .send()
287                .await?;
288        }
289
290        Ok(resp)
291    }
292
293    pub async fn change_message_visibility(
294        &self,
295        mut change_message_visibility: ChangeMessageVisibilityFluentBuilder,
296    ) -> Result<ChangeMessageVisibilityOutput, SqsExtendedClientError> {
297        let receipt_handle: String = match change_message_visibility.get_receipt_handle() {
298            None => return Err(SqsExtendedClientError::NoReceiptHandle),
299            Some(rh) => rh.to_string(),
300        };
301
302        let (bucket, key, handle) = self.parse_extended_receipt_handle(receipt_handle.clone());
303
304        if !bucket.is_empty() && !key.is_empty() && !handle.is_empty() {
305            change_message_visibility =
306                change_message_visibility.set_receipt_handle(Some(handle.clone()));
307        }
308
309        let resp: ChangeMessageVisibilityOutput = change_message_visibility.send().await?;
310
311        Ok(resp)
312    }
313
314    fn message_exceeds_threshold(
315        &self,
316        body: &str,
317        attributes: &Option<HashMap<String, MessageAttributeValue>>,
318    ) -> bool {
319        self.message_size(body, attributes).total() > self.message_size_threshold
320    }
321
322    fn message_size(
323        &self,
324        body: &str,
325        attributes: &Option<HashMap<String, MessageAttributeValue>>,
326    ) -> MessageSize {
327        MessageSize {
328            body_size: body.len(),
329            attribute_size: self.attribute_size(attributes),
330        }
331    }
332
333    fn attribute_size(&self, attributes: &Option<HashMap<String, MessageAttributeValue>>) -> usize {
334        match attributes {
335            None => 0,
336            Some(hash_map) => self.calc_attribute_size(hash_map),
337        }
338    }
339
340    fn calc_attribute_size(&self, attributes: &HashMap<String, MessageAttributeValue>) -> usize {
341        let mut sum: usize = 0;
342        for (k, v) in attributes {
343            sum += k.len();
344
345            match &v.binary_value {
346                None => {}
347                Some(blob) => {
348                    sum += blob.as_ref().len();
349                }
350            }
351            match &v.string_value {
352                None => {}
353                Some(string) => sum += string.len(),
354            }
355            sum += v.data_type.len();
356        }
357        sum
358    }
359
360    fn s3_key(&self, filename: String) -> String {
361        if !self.object_prefix.is_empty() {
362            return format!("{}/{}", self.object_prefix, filename);
363        }
364        filename
365    }
366
367    fn new_extended_receipt_handle(bucket: String, key: String, handle: String) -> String {
368        let s3_bucket_name_marker: String = "-..s3BucketName..-".to_string();
369        let s3_key_marker: String = "-..s3Key..-".to_string();
370
371        format!(
372            "{}{}{}{}{}{}{}",
373            s3_bucket_name_marker,
374            bucket,
375            s3_bucket_name_marker,
376            s3_key_marker,
377            key,
378            s3_key_marker,
379            handle
380        )
381        .to_string()
382    }
383
384    fn parse_extended_receipt_handle(
385        &self,
386        extended_receipt_handle: String,
387    ) -> (String, String, String) {
388        let caps = match self
389            .extended_receipt_handler_regex
390            .captures(&extended_receipt_handle)
391        {
392            Some(caps) => caps,
393            None => return ("".to_string(), "".to_string(), "".to_string()),
394        };
395
396        if caps.len() != 4 {
397            return ("".to_string(), "".to_string(), "".to_string());
398        }
399
400        let bucket = caps
401            .get(1)
402            .map(|m| m.as_str().to_string())
403            .unwrap_or_default();
404        let key = caps
405            .get(2)
406            .map(|m| m.as_str().to_string())
407            .unwrap_or_default();
408        let receipt_handle = caps
409            .get(3)
410            .map(|m| m.as_str().to_string())
411            .unwrap_or_default();
412
413        (bucket, key, receipt_handle)
414    }
415}
416
417//-S3 POINTER-------------------------------------------------------------------
418
419#[derive(Serialize, Deserialize, Debug)]
420struct S3PointerArray(String, S3PointerBucketAndKeyObject);
421
422#[derive(Serialize, Deserialize, Debug)]
423struct S3PointerBucketAndKeyObject {
424    #[serde(rename = "s3BucketName")]
425    s3_bucket_name: String,
426    #[serde(rename = "s3Key")]
427    s3_key: String,
428}
429
430#[derive(Serialize, Deserialize, Debug)]
431struct S3Pointer {
432    s3_bucket_name: String,
433    s3_key: String,
434    class: String,
435}
436
437impl S3Pointer {
438    fn marshall_json(self) -> String {
439        format!(
440            "[\"{}\",{{\"s3BucketName\":\"{}\",\"s3Key\":\"{}\"}}]",
441            self.class, self.s3_bucket_name, self.s3_key
442        )
443    }
444
445    fn unmarshall_json(input: &str) -> SerdeJsonResult<S3Pointer> {
446        let wrapper: S3PointerArray = serde_json::from_str(input)?;
447
448        let s3_pointer: S3Pointer = S3Pointer {
449            s3_bucket_name: wrapper.1.s3_bucket_name,
450            s3_key: wrapper.1.s3_key,
451            class: wrapper.0,
452        };
453
454        Ok(s3_pointer)
455    }
456}
457
458//-MESSAGE SIZE-----------------------------------------------------------------
459
460struct MessageSize {
461    body_size: usize,
462    attribute_size: usize,
463}
464
465impl MessageSize {
466    fn total(self) -> usize {
467        self.body_size + self.attribute_size
468    }
469}
470
471//-ERRORS-----------------------------------------------------------------------
472
473#[derive(Debug)]
474pub enum SqsExtendedClientError {
475    S3Upload(SdkError<PutObjectError, HttpResponse>),
476    S3Download(SdkError<GetObjectError, Response>),
477    S3DeleteObject(SdkError<DeleteObjectError, Response>),
478    S3DownloadToBytes(ByteStreamError),
479    S3DownloadToUtf8(Utf8Error),
480    SqsSendMessage(SdkError<SendMessageError, HttpResponse>),
481    SqsReceiveMessage(SdkError<ReceiveMessageError, HttpResponse>),
482    SqsDeleteMessage(SdkError<DeleteMessageError, Response>),
483    SqsChangeMessageVisibility(SdkError<ChangeMessageVisibilityError, Response>),
484    SqsBuildMessageAttribute(BuildError),
485    SqsReceiveMessageUnMarshallMessageBody(serde_json::Error),
486    NoBucketName,
487    NoMessageBody,
488    NoReceiptHandle,
489}
490
491impl fmt::Display for SqsExtendedClientError {
492    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
493        match self {
494            Self::S3Upload(err) => write!(f, "S3 upload failed: {}", err),
495            Self::S3Download(err) => write!(f, "S3 download failed: {}", err),
496            Self::S3DeleteObject(err) => write!(f, "S3 delete failed: {}", err),
497            Self::S3DownloadToBytes(err) => write!(f, "S3 Byte Stream Error: {}", err),
498            Self::S3DownloadToUtf8(err) => write!(f, "S3 Byte Stream Error: {}", err),
499            Self::SqsSendMessage(err) => write!(f, "SQS operation failed: {}", err),
500            Self::SqsReceiveMessage(err) => write!(f, "SQS operation failed: {}", err),
501            Self::SqsDeleteMessage(err) => write!(f, "SQS delete failed: {}", err),
502            Self::SqsChangeMessageVisibility(err) => {
503                write!(f, "SQS change message visibilty failed: {}", err)
504            }
505            Self::SqsBuildMessageAttribute(err) => {
506                write!(f, "SQS build message attribute failed: {}", err)
507            }
508            Self::SqsReceiveMessageUnMarshallMessageBody(err) => {
509                write!(f, "Failed to marshall sqs message body: {}", err)
510            }
511            Self::NoBucketName => write!(f, "No bucket name configured"),
512            Self::NoMessageBody => write!(f, "No message body"),
513            Self::NoReceiptHandle => write!(f, "No receipt handle"),
514        }
515    }
516}
517
518impl From<aws_sdk_s3::error::BuildError> for SqsExtendedClientError {
519    fn from(err: aws_sdk_s3::error::BuildError) -> Self {
520        Self::SqsBuildMessageAttribute(err)
521    }
522}
523
524impl From<Utf8Error> for SqsExtendedClientError {
525    fn from(err: Utf8Error) -> Self {
526        Self::S3DownloadToUtf8(err)
527    }
528}
529
530impl From<ByteStreamError> for SqsExtendedClientError {
531    fn from(err: ByteStreamError) -> Self {
532        Self::S3DownloadToBytes(err)
533    }
534}
535
536impl From<SdkError<GetObjectError, Response>> for SqsExtendedClientError {
537    fn from(err: SdkError<GetObjectError, Response>) -> Self {
538        Self::S3Download(err)
539    }
540}
541
542impl From<SdkError<DeleteObjectError, Response>> for SqsExtendedClientError {
543    fn from(err: SdkError<DeleteObjectError, Response>) -> Self {
544        Self::S3DeleteObject(err)
545    }
546}
547
548impl From<SdkError<ReceiveMessageError, HttpResponse>> for SqsExtendedClientError {
549    fn from(err: SdkError<ReceiveMessageError, HttpResponse>) -> Self {
550        Self::SqsReceiveMessage(err)
551    }
552}
553
554impl From<SdkError<DeleteMessageError, Response>> for SqsExtendedClientError {
555    fn from(err: SdkError<DeleteMessageError, Response>) -> Self {
556        Self::SqsDeleteMessage(err)
557    }
558}
559
560impl From<SdkError<ChangeMessageVisibilityError, Response>> for SqsExtendedClientError {
561    fn from(err: SdkError<ChangeMessageVisibilityError, Response>) -> Self {
562        Self::SqsChangeMessageVisibility(err)
563    }
564}
565
566impl From<serde_json::Error> for SqsExtendedClientError {
567    fn from(err: serde_json::Error) -> Self {
568        Self::SqsReceiveMessageUnMarshallMessageBody(err)
569    }
570}
571
572impl std::error::Error for SqsExtendedClientError {}
573
574//-TESTS------------------------------------------------------------------------
575
576#[cfg(test)]
577mod tests {
578    use aws_config::BehaviorVersion;
579    use aws_smithy_types::Blob;
580
581    use super::*;
582
583    fn make_test_credentials() -> aws_sdk_s3::config::Credentials {
584        aws_sdk_s3::config::Credentials::new(
585            "TEST_ACCESS_KEY_ID",
586            "TEST_SECRET_ACCESS_KEY",
587            Some("TEST_SESSION_TOKEN".to_string()),
588            None,
589            "",
590        )
591    }
592
593    fn make_test_s3_client() -> aws_sdk_s3::client::Client {
594        aws_sdk_s3::Client::from_conf(
595            aws_sdk_s3::Config::builder()
596                .behavior_version(BehaviorVersion::latest())
597                .credentials_provider(make_test_credentials())
598                .build(),
599        )
600    }
601
602    fn make_test_sqs_client() -> aws_sdk_sqs::client::Client {
603        aws_sdk_sqs::Client::from_conf(
604            aws_sdk_sqs::Config::builder()
605                .behavior_version(BehaviorVersion::latest())
606                .credentials_provider(make_test_credentials())
607                .build(),
608        )
609    }
610
611    #[test]
612    fn test_builder_fns() {
613        let sqs_extended_client: SqsExtendedClient =
614            SqsExtendedClientBuilder::new(make_test_s3_client())
615                .with_s3_bucket_name("bucket-name".to_string())
616                .with_message_size_threshold(9999)
617                .with_batch_message_size_threshold(1000)
618                .with_always_through_s3(true)
619                .with_reserved_attribute_names(vec!["attr_one".to_string(), "attr_two".to_string()])
620                .with_pointer_class("pointer-class".to_string())
621                .with_object_prefix("object-prefix".to_string())
622                .build();
623
624        let bucket_name: String = sqs_extended_client.bucket_name.unwrap_or_default();
625        assert_eq!("bucket-name", bucket_name);
626        assert_eq!(9999, sqs_extended_client.message_size_threshold);
627        assert!(sqs_extended_client.always_through_s3);
628        assert_eq!(
629            vec!["attr_one".to_string(), "attr_two".to_string()],
630            sqs_extended_client.reserved_attributes
631        );
632        assert_eq!("pointer-class", sqs_extended_client.pointer_class);
633        assert_eq!("object-prefix", sqs_extended_client.object_prefix);
634    }
635
636    #[test]
637    fn test_builder_defaults() {
638        let sqs_extended_client: SqsExtendedClient =
639            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
640
641        let bucket_name: String = sqs_extended_client.bucket_name.unwrap_or_default();
642
643        assert_eq!("", bucket_name);
644        assert_eq!(
645            MAX_MESSAGE_SIZE_IN_BYTES,
646            sqs_extended_client.message_size_threshold
647        );
648        assert!(!sqs_extended_client.always_through_s3);
649        assert_eq!(
650            vec![
651                "ExtendedPayloadSize".to_string(),
652                LEGACY_RESERVED_ATTRIBUTE_NAME.to_string(),
653            ],
654            sqs_extended_client.reserved_attributes
655        );
656        assert_eq!(DEFAULT_POINTER_CLASS, sqs_extended_client.pointer_class);
657        assert_eq!("", sqs_extended_client.object_prefix);
658    }
659
660    // send_message
661    // receive_message
662    // delete_message
663    // change_message_visibility
664
665    #[test]
666    fn test_message_exceeds_threshold() {
667        // need table tests -> macros in rust apparently ...
668        let sqs_extended_client_small_msg_size_threshold: SqsExtendedClient =
669            SqsExtendedClientBuilder::new(make_test_s3_client())
670                .with_message_size_threshold(8)
671                .build();
672
673        let msg: SendMessageFluentBuilder = make_test_sqs_client()
674            .send_message()
675            .queue_url("queue_url")
676            .message_body("hello world");
677
678        let msg_body: &String = msg.get_message_body().as_ref().unwrap();
679
680        let result: bool = sqs_extended_client_small_msg_size_threshold
681            .message_exceeds_threshold(msg_body, msg.get_message_attributes());
682
683        assert!(result);
684
685        let sqs_extended_client_default: SqsExtendedClient =
686            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
687
688        let result_2: bool = sqs_extended_client_default
689            .message_exceeds_threshold(msg_body, msg.get_message_attributes());
690
691        assert!(!result_2);
692    }
693
694    #[test]
695    fn test_message_size() {
696        let sqs_extended_client: SqsExtendedClient =
697            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
698
699        let msg: SendMessageFluentBuilder = make_test_sqs_client()
700            .send_message()
701            .queue_url("queue_url")
702            .message_body("hello world");
703
704        let msg_body: &String = msg.get_message_body().as_ref().unwrap();
705
706        let result: MessageSize =
707            sqs_extended_client.message_size(msg_body, msg.get_message_attributes());
708
709        assert!(result.total() == 11);
710
711        let attribute_value_number: MessageAttributeValue = MessageAttributeValue::builder()
712            .data_type("Number")
713            .string_value("12")
714            .build()
715            .expect("Failed to build MessageAttributeValue");
716
717        let attribute_value_binary: MessageAttributeValue = MessageAttributeValue::builder()
718            .data_type("Binary")
719            .binary_value(Blob::new("IT'S BINARY HONEST"))
720            .build()
721            .expect("Failed to build MessageAttributeValue");
722
723        let msg_with_attributes: SendMessageFluentBuilder = make_test_sqs_client()
724            .send_message()
725            .queue_url("queue_url")
726            .message_body("hello world")
727            .message_attributes("test_attribute_one", attribute_value_number)
728            .message_attributes("test_attribute_two", attribute_value_binary);
729
730        let msg_with_attributes_body: &String =
731            msg_with_attributes.get_message_body().as_ref().unwrap();
732
733        let result_2: MessageSize = sqs_extended_client.message_size(
734            msg_with_attributes_body,
735            msg_with_attributes.get_message_attributes(),
736        );
737
738        assert!(result_2.total() == 79);
739    }
740
741    #[test]
742    fn test_calc_attribute_size() {
743        let sqs_extended_client: SqsExtendedClient =
744            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
745
746        let reserved_attribute = MessageAttributeValue::builder()
747            .data_type(String::from("String"))
748            .string_value(String::from("some string"))
749            .build()
750            .expect("build MessageAttrbuteValue should not fail");
751
752        let mut hm: HashMap<String, MessageAttributeValue> = HashMap::new();
753        hm.insert(String::from("testing_strings"), reserved_attribute);
754
755        assert_eq!(32, sqs_extended_client.calc_attribute_size(&hm))
756    }
757
758    #[test]
759    fn test_s3_key() {
760        let sqs_extended_client_no_prefix: SqsExtendedClient =
761            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
762
763        assert_eq!(
764            String::from("00000000-0000-0000-0000-000000000000"),
765            sqs_extended_client_no_prefix
766                .s3_key(String::from("00000000-0000-0000-0000-000000000000"))
767        );
768
769        let sqs_extended_client_with_prefix: SqsExtendedClient =
770            SqsExtendedClientBuilder::new(make_test_s3_client())
771                .with_object_prefix(String::from("some_prefix"))
772                .build();
773
774        assert_eq!(
775            String::from("some_prefix/00000000-0000-0000-0000-000000000000"),
776            sqs_extended_client_with_prefix
777                .s3_key(String::from("00000000-0000-0000-0000-000000000000"))
778        )
779    }
780
781    #[test]
782    fn test_new_extended_receipt_handle() {
783        let bucket: String = "BUCKET".to_string();
784        let key: String = "KEY".to_string();
785        let handle: String = "HANDLE".to_string();
786
787        let result: String = SqsExtendedClient::new_extended_receipt_handle(bucket, key, handle);
788
789        assert!(
790            result == "-..s3BucketName..-BUCKET-..s3BucketName..--..s3Key..-KEY-..s3Key..-HANDLE"
791        );
792    }
793
794    #[test]
795    fn test_parse_extended_receipt_handle() {
796        let sqs_extended_client: SqsExtendedClient =
797            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
798
799        let extended_receipt_handler: String =
800            "-..s3BucketName..-BUCKET-..s3BucketName..--..s3Key..-KEY-..s3Key..-HANDLE".to_string();
801
802        let (bucket, key, receipt_handle) =
803            sqs_extended_client.parse_extended_receipt_handle(extended_receipt_handler);
804
805        assert!(bucket == "BUCKET");
806        assert!(key == "KEY");
807        assert!(receipt_handle == "HANDLE");
808    }
809
810    #[test]
811    fn test_marshall_json() {
812        let sqs_extended_client: SqsExtendedClient =
813            SqsExtendedClientBuilder::new(make_test_s3_client()).build();
814
815        let s3_pointer: S3Pointer = S3Pointer {
816            s3_bucket_name: "BUCKET".to_string(),
817            s3_key: "KEY".to_string(),
818            class: sqs_extended_client.pointer_class.clone(),
819        };
820
821        let json_s3_pointer: String = s3_pointer.marshall_json();
822
823        assert!(
824            json_s3_pointer
825                == r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{"s3BucketName":"BUCKET","s3Key":"KEY"}]"#
826        );
827    }
828
829    #[test]
830    fn test_unmarshall_json() {
831        let s3_pointer_str: &str = r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{"s3BucketName":"BUCKET","s3Key":"KEY"}]"#;
832
833        let s3_pointer_struct: S3Pointer =
834            S3Pointer::unmarshall_json(s3_pointer_str).expect("s3_pointer unmarshall failed");
835
836        assert!(s3_pointer_struct.s3_bucket_name == "BUCKET");
837        assert!(s3_pointer_struct.s3_key == "KEY");
838        assert!(
839            s3_pointer_struct.class
840                == SqsExtendedClientBuilder::new(make_test_s3_client())
841                    .build()
842                    .pointer_class
843                    .clone()
844        );
845    }
846}