use crate::error::{Error, Result};
use futures01::future::Future;
use lazy_static::lazy_static;
use log::{debug, error};
use md5::Digest;
use rusoto_core::{
credential::ProvideAwsCredentials, request::HttpClient, DispatchSignedRequest, Region,
RusotoError,
};
use rusoto_credential::DefaultCredentialsProvider;
use rusoto_mock::{MockResponseReader, ReadMockResponse};
use rusoto_sqs::{
GetQueueUrlError, GetQueueUrlRequest, GetQueueUrlResult, MessageAttributeValue,
SendMessageError, SendMessageRequest, SendMessageResult, Sqs, SqsClient,
};
use serde::Serialize;
use std::{
collections::{BTreeMap, HashMap},
fmt,
};
use tokio_core::reactor::Core;
use uuid::Uuid;
/// Blocking SQS publisher.
///
/// Holds its own `tokio_core` reactor and drives every publish request to
/// completion on it before returning.
pub struct Pub {
    /// Reactor on which `publish` runs its request futures.
    core: Core,
    /// Rusoto SQS client used for `GetQueueUrl` and `SendMessage`.
    client: SqsClient,
    /// Name of the destination queue; resolved to a URL on every publish.
    queue_name: String,
}
impl fmt::Debug for Pub {
    /// Formats as `Pub -> Queue (<queue name>)`; only the queue name is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Pub -> Queue (")?;
        f.write_str(&self.queue_name)?;
        f.write_str(")")
    }
}
impl Pub {
pub fn initialize<T>(region: Region, queue_name: T) -> Result<Self>
where
T: Into<String>,
{
let provider = DefaultCredentialsProvider::new()?;
let rusoto_client = HttpClient::from_connector(super::new_connector()?);
Self::initialize_internal(region, queue_name, provider, rusoto_client)
}
#[doc(hidden)]
pub fn initialize_internal<P, R, T>(
region: Region,
queue_name: T,
provider: P,
request_dispatcher: R,
) -> Result<Self>
where
P: ProvideAwsCredentials + Send + Sync + 'static,
P::Future: Send,
R: DispatchSignedRequest + Send + Sync + 'static,
R::Future: Send,
T: Into<String>,
{
let client = SqsClient::new_with(request_dispatcher, provider, region);
let core = Core::new()?;
Ok(Self {
core,
client,
queue_name: queue_name.into(),
})
}
pub fn publish<M>(
&mut self,
message: M,
attributes: Option<HashMap<String, String>>,
) -> Result<Uuid>
where
M: Serialize,
{
let message_json = serde_json::to_string(&message)?;
let message_attributes = convert_attributes(attributes);
let digests = Digests {
message_body: Some(md5::compute(message_json.as_bytes())),
message_attributes: message_attributes.clone().and_then(|ma| ma_md5(&ma)),
message_system_attributes: None,
};
let client_clone = self.client.clone();
let publish_future = self
.client
.get_queue_url(super::get_queue_url_request(self.queue_name.clone()))
.map_err(map_re)
.and_then(|result| {
client_clone.send_message(send_message_request(
message_json,
message_attributes,
result.queue_url.unwrap(),
))
});
match self.core.run(publish_future) {
Ok(publish_result) => {
if check_digests(&publish_result, &digests) {
Ok(Uuid::parse_str(
&publish_result
.message_id
.ok_or_else(Error::invalid_message_id)?,
)?)
} else {
Err(Error::digest_error())
}
}
Err(e) => {
error!("{}", e);
Err(Error::publish_error())
}
}
}
#[doc(hidden)]
#[must_use]
pub fn mock_responses() -> HashMap<String, Vec<u8>> {
lazy_static! {
static ref BODY_MAP: HashMap<String, Vec<u8>> = {
let mut body_map = HashMap::new();
let _ = body_map.insert(
"GetQueueUrl",
MockResponseReader::read_response("test-data", "get_queue_url_response.xml"),
);
let _ = body_map.insert(
"SendMessage",
MockResponseReader::read_response("test-data", "send_message_response.xml"),
);
body_map
.iter()
.map(|(k, v)| ((*k).to_string(), v.as_bytes().to_vec()))
.collect()
};
}
BODY_MAP.clone()
}
}
/// Future-returning (futures 0.1) SQS publisher.
///
/// Unlike `Pub` it owns no reactor; callers drive the futures returned by
/// `publish_fut` on their own executor.
#[derive(Clone)]
pub struct PubFut {
    /// Rusoto SQS client used for `GetQueueUrl` and `SendMessage`.
    client: SqsClient,
    /// Name of the destination queue; resolved to a URL on every publish.
    queue_name: String,
}
impl fmt::Debug for PubFut {
    /// Formats as `PubFut -> Queue (<queue name>)`; only the queue name is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PubFut -> Queue (")?;
        f.write_str(&self.queue_name)?;
        f.write_str(")")
    }
}
impl PubFut {
pub fn initialize<T>(region: Region, queue_name: T) -> Result<Self>
where
T: Into<String>,
{
let provider = DefaultCredentialsProvider::new()?;
let rusoto_client = HttpClient::from_connector(super::new_connector()?);
Self::initialize_internal(region, queue_name, provider, rusoto_client)
}
#[doc(hidden)]
pub fn initialize_internal<P, R, T>(
region: Region,
queue_name: T,
provider: P,
request_dispatcher: R,
) -> Result<Self>
where
P: ProvideAwsCredentials + Send + Sync + 'static,
P::Future: Send,
R: DispatchSignedRequest + Send + Sync + 'static,
R::Future: Send,
T: Into<String>,
{
let client = SqsClient::new_with(request_dispatcher, provider, region);
Ok(Self {
client,
queue_name: queue_name.into(),
})
}
pub fn publish_fut<M>(
&mut self,
message: M,
attributes: Option<HashMap<String, String>>,
) -> impl Future<Item = SendMessageResult, Error = Error>
where
M: Serialize,
{
let message_json = serde_json::to_string(&message).unwrap();
let message_attributes = convert_attributes(attributes);
let client_clone = self.client.clone();
let get_queue_url = PubFut::make_get_queue_url_future(
&self.client,
super::get_queue_url_request(self.queue_name.clone()),
);
let send_message = move |result: GetQueueUrlResult| {
client_clone.send_message(send_message_request(
message_json,
message_attributes,
result.queue_url.unwrap(),
))
};
get_queue_url
.map_err(map_re)
.and_then(send_message)
.map_err(map_e)
}
fn make_get_queue_url_future(
client: &SqsClient,
request: GetQueueUrlRequest,
) -> impl Future<Item = GetQueueUrlResult, Error = RusotoError<GetQueueUrlError>> {
client.get_queue_url(request)
}
#[doc(hidden)]
#[must_use]
pub fn mock_responses() -> HashMap<String, Vec<u8>> {
lazy_static! {
static ref BODY_MAP: HashMap<String, Vec<u8>> = {
let mut body_map = HashMap::new();
let _ = body_map.insert(
"GetQueueUrl",
MockResponseReader::read_response("test-data", "get_queue_url_response.xml"),
);
let _ = body_map.insert(
"SendMessage",
MockResponseReader::read_response("test-data", "send_message_response.xml"),
);
body_map
.iter()
.map(|(k, v)| ((*k).to_string(), v.as_bytes().to_vec()))
.collect()
};
}
BODY_MAP.clone()
}
}
/// Folds a `GetQueueUrl` failure into the `SendMessage` error type, so both
/// steps of the publish chain share one error type.
///
/// Lossy: the original error survives only as its `Display` text inside
/// `SendMessageError::UnsupportedOperation`; its variant is not preserved.
// Taken by value because `map_err` passes ownership; this lets the function be
// handed straight to `.map_err(map_re)`.
#[allow(clippy::needless_pass_by_value)]
fn map_re(e: RusotoError<GetQueueUrlError>) -> RusotoError<SendMessageError> {
    let description = e.to_string();
    RusotoError::Service(SendMessageError::UnsupportedOperation(description))
}
fn map_e(e: RusotoError<SendMessageError>) -> Error {
e.into()
}
/// Builds a `SendMessageRequest` for `queue_url` carrying `message_body` and
/// the optional `message_attributes`; all other fields take their defaults.
fn send_message_request(
    message_body: String,
    message_attributes: Option<HashMap<String, MessageAttributeValue>>,
    queue_url: String,
) -> SendMessageRequest {
    SendMessageRequest {
        queue_url,
        message_body,
        message_attributes,
        ..Default::default()
    }
}
/// Locally computed MD5 digests to compare against the ones SQS reports in a
/// `SendMessageResult`. A `None` field means there is no local digest, so that
/// check is skipped.
struct Digests {
    /// MD5 of the JSON-serialized message body.
    message_body: Option<Digest>,
    /// MD5 of the encoded message attributes, when attributes were sent.
    message_attributes: Option<Digest>,
    /// MD5 of the message system attributes (`Pub::publish` sets this to `None`).
    message_system_attributes: Option<Digest>,
}
/// Returns `true` when every digest SQS reported matches the local one.
///
/// The body is checked first, then the attributes, then the system attributes;
/// checking stops at the first mismatch.
fn check_digests(message_result: &SendMessageResult, digests: &Digests) -> bool {
    let comparisons = [
        (&message_result.md5_of_message_body, digests.message_body),
        (
            &message_result.md5_of_message_attributes,
            digests.message_attributes,
        ),
        (
            &message_result.md5_of_message_system_attributes,
            digests.message_system_attributes,
        ),
    ];
    comparisons
        .iter()
        .all(|(theirs, mine)| check_digest(theirs, *mine))
}
/// Compares the hex digest SQS returned (`theirs`) with the local digest.
///
/// If either side is absent there is nothing to compare, and the check passes.
fn check_digest(theirs: &Option<String>, mine: Option<Digest>) -> bool {
    match (theirs.as_ref(), mine) {
        (Some(their_hex), Some(my_digest)) => {
            debug!("Mine: {:x}, Theirs: {}", my_digest, their_hex);
            format!("{:x}", my_digest) == *their_hex
        }
        _ => true,
    }
}
/// Turns plain `name -> value` pairs into SQS message attributes of data type
/// `"String"`. Returns `None` when `attributes` is `None`.
fn convert_attributes(
    attributes: Option<HashMap<String, String>>,
) -> Option<HashMap<String, MessageAttributeValue>> {
    let raw = attributes?;
    let mut converted = HashMap::with_capacity(raw.len());
    for (name, value) in raw {
        let attribute = MessageAttributeValue {
            binary_list_values: None,
            binary_value: None,
            data_type: "String".to_string(),
            string_list_values: None,
            string_value: Some(value),
        };
        let _ = converted.insert(name, attribute);
    }
    Some(converted)
}
/// Computes the MD5 digest of `attributes` using the encoding SQS uses for
/// `MD5OfMessageAttributes`.
///
/// Attributes are processed in ascending name order. For each one the name and
/// the data type are written as length-prefixed bytes (see `encode_bytes`), then
/// a transport-type byte, then the length-prefixed value:
/// - `1` and `string_value` for `String` and `Number` types (including custom
///   subtypes such as `Number.int`),
/// - `2` and `binary_value` for everything else (`Binary`).
///
/// Always returns `Some`; the `Option` is kept for the callers' `and_then` use.
fn ma_md5(attributes: &HashMap<String, MessageAttributeValue>) -> Option<Digest> {
    // Sort borrowed entries instead of cloning every name and value.
    let mut sorted: Vec<(&String, &MessageAttributeValue)> = attributes.iter().collect();
    sorted.sort_unstable_by(|a, b| a.0.cmp(b.0));

    let mut buffer: Vec<u8> = Vec::new();
    for (name, value) in sorted {
        encode_bytes(&mut buffer, name.as_bytes());
        encode_bytes(&mut buffer, value.data_type.as_bytes());
        // SQS transmits Number values as strings, so they share transport type 1.
        let is_string_transport =
            value.data_type.starts_with("String") || value.data_type.starts_with("Number");
        if is_string_transport {
            buffer.push(1);
            if let Some(string_value) = &value.string_value {
                encode_bytes(&mut buffer, string_value.as_bytes());
            }
        } else {
            buffer.push(2);
            if let Some(binary_value) = &value.binary_value {
                encode_bytes(&mut buffer, &binary_value[..]);
            }
        }
    }
    Some(md5::compute(&buffer))
}
/// Appends `value` to `buffer`, preceded by its length as a 4-byte big-endian
/// integer.
///
/// A length that does not fit in `u32` is written as `0` (the bytes are still
/// appended).
fn encode_bytes(buffer: &mut Vec<u8>, value: &[u8]) {
    use std::convert::TryFrom;
    let prefix = u32::try_from(value.len()).unwrap_or(0).to_be_bytes();
    buffer.extend_from_slice(&prefix);
    buffer.extend_from_slice(value);
}
#[cfg(test)]
mod test {
    use super::{convert_attributes, ma_md5, Pub, PubFut};
    use crate::{error::Result, utils::MockRequestDispatcher};
    use futures::compat::Future01CompatExt;
    use rusoto_core::Region;
    use rusoto_mock::MockCredentialsProvider;
    use std::collections::HashMap;

    /// Blocking publish against canned SQS responses succeeds.
    #[test]
    fn publish() -> Result<()> {
        let _ = pretty_env_logger::try_init_timed();
        let dispatcher = MockRequestDispatcher::default().with_body_map(Pub::mock_responses());
        let mut publisher = Pub::initialize_internal(
            Region::UsEast2,
            "Orders",
            MockCredentialsProvider,
            dispatcher,
        )?;
        let outcome = publisher.publish("test_message", None);
        assert!(outcome.is_ok());
        Ok(())
    }

    /// Future-based publish yields the message id from the canned response.
    #[tokio::test]
    async fn publish_fut() -> Result<()> {
        let _ = pretty_env_logger::try_init_timed();
        let dispatcher =
            MockRequestDispatcher::default().with_body_map(PubFut::mock_responses());
        let mut publisher = PubFut::initialize_internal(
            Region::UsEast2,
            "Orders",
            MockCredentialsProvider,
            dispatcher,
        )?;
        let sent = publisher.publish_fut("test_message", None).compat().await?;
        let expected_id = Some("5fea7756-0ea4-451a-a703-a558b933e274".to_string());
        assert_eq!(sent.message_id, expected_id);
        Ok(())
    }

    /// The attribute digest matches a known-good value.
    #[test]
    fn md5_matches() -> Result<()> {
        let attr: HashMap<String, String> = [("kind", "create"), ("alpha", "beta")]
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect();
        let digest = convert_attributes(Some(attr))
            .and_then(|converted| ma_md5(&converted))
            .ok_or("invalid digest")?;
        assert_eq!(format!("{:x}", digest), "16adff48f45c49385c23f8603fa85845");
        Ok(())
    }
}