1use async_trait::async_trait;
37use aws_config::{BehaviorVersion, Region};
38use aws_sdk_sns::{config::Credentials, Client as SnsClient, Config as SnsConfig};
39use serde::{Deserialize, Serialize};
40use sms_core::*;
41use std::collections::HashMap;
42use tracing::{debug, error, info, warn};
43
44#[derive(Debug, Clone)]
57pub struct AwsSnsClient {
58 client: SnsClient,
59 region: String,
60}
61
62#[derive(Debug, Deserialize, Serialize)]
65pub struct SnsDeliveryNotification {
66 #[serde(rename = "Type")]
68 pub notification_type: String,
69 #[serde(rename = "MessageId")]
71 pub message_id: String,
72 #[serde(rename = "TopicArn")]
74 pub topic_arn: String,
75 #[serde(rename = "Message")]
77 pub message: String,
78 #[serde(rename = "Timestamp")]
80 pub timestamp: String,
81 #[serde(rename = "SignatureVersion")]
83 pub signature_version: String,
84 #[serde(rename = "Signature")]
86 pub signature: String,
87 #[serde(rename = "SigningCertURL")]
89 pub signing_cert_url: String,
90}
91
92#[derive(Debug, Deserialize, Serialize)]
95pub struct SmsDeliveryReport {
96 pub notification: SmsNotificationData,
98 pub delivery: SmsDeliveryData,
100 pub status: String,
102 #[serde(rename = "messageId")]
104 pub message_id: String,
105 #[serde(rename = "destinationPhoneNumber")]
107 pub destination_phone_number: String,
108}
109
110#[derive(Debug, Deserialize, Serialize)]
112pub struct SmsNotificationData {
113 #[serde(rename = "messageId")]
115 pub message_id: String,
116 pub timestamp: String,
118}
119
120#[derive(Debug, Deserialize, Serialize)]
122pub struct SmsDeliveryData {
123 pub destination: String,
125 #[serde(rename = "priceInUSD")]
127 pub price_in_usd: Option<f64>,
128 #[serde(rename = "smsType")]
130 pub sms_type: String,
131 #[serde(rename = "dwellTimeMs")]
133 pub dwell_time_ms: Option<u64>,
134 #[serde(rename = "dwellTimeMsUntilDeviceAck")]
136 pub dwell_time_ms_until_device_ack: Option<u64>,
137}
138
139impl AwsSnsClient {
140 pub fn new(
148 region: impl Into<String>,
149 access_key_id: impl Into<String>,
150 secret_access_key: impl Into<String>,
151 ) -> Self {
152 let region_str = region.into();
153 let region_copy = region_str.clone();
154 let aws_region = Region::from_static(Box::leak(region_copy.into_boxed_str()));
155
156 let credentials = Credentials::new(
157 access_key_id,
158 secret_access_key,
159 None,
160 None,
161 "smskit",
162 );
163
164 let config = SnsConfig::builder()
165 .region(aws_region)
166 .credentials_provider(credentials)
167 .behavior_version(BehaviorVersion::latest())
168 .build();
169
170 let client = SnsClient::from_conf(config);
171
172 Self {
173 client,
174 region: region_str,
175 }
176 }
177
178 pub fn from_env() -> Result<Self, SmsError> {
188 let region = std::env::var("AWS_REGION")
189 .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
190 .map_err(|_| SmsError::Auth("AWS_REGION (or AWS_DEFAULT_REGION) not set".into()))?;
191 let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
192 .map_err(|_| SmsError::Auth("AWS_ACCESS_KEY_ID not set".into()))?;
193 let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
194 .map_err(|_| SmsError::Auth("AWS_SECRET_ACCESS_KEY not set".into()))?;
195 Ok(Self::new(region, access_key_id, secret_access_key))
196 }
197
198 pub async fn with_default_credentials(region: impl Into<String>) -> Self {
204 let region_str = region.into();
205 let aws_region = Region::new(region_str.clone());
206 let config = aws_config::defaults(BehaviorVersion::latest())
207 .region(aws_region)
208 .load()
209 .await;
210
211 let client = SnsClient::new(&config);
212
213 Self {
214 client,
215 region: region_str,
216 }
217 }
218}
219
220#[async_trait]
221impl SmsClient for AwsSnsClient {
222 async fn send(&self, req: SendRequest<'_>) -> Result<SendResponse, SmsError> {
223 info!("Sending SMS via AWS SNS to {}", req.to);
224
225 let mut message_attributes = HashMap::new();
226
227 message_attributes.insert(
228 "AWS.SNS.SMS.SMSType".to_string(),
229 aws_sdk_sns::types::MessageAttributeValue::builder()
230 .data_type("String")
231 .string_value("Transactional")
232 .build()
233 .map_err(|e| {
234 SmsError::Provider(format!("Failed to build SMS type attribute: {}", e))
235 })?,
236 );
237
238 if !req.from.is_empty() && !req.from.starts_with('+') {
239 message_attributes.insert(
240 "AWS.SNS.SMS.SenderID".to_string(),
241 aws_sdk_sns::types::MessageAttributeValue::builder()
242 .data_type("String")
243 .string_value(req.from)
244 .build()
245 .map_err(|e| {
246 SmsError::Provider(format!("Failed to build sender ID attribute: {}", e))
247 })?,
248 );
249 }
250
251 debug!(
252 "Sending SNS message with attributes: {:?}",
253 message_attributes
254 );
255
256 let result = self
257 .client
258 .publish()
259 .phone_number(req.to)
260 .message(req.text)
261 .set_message_attributes(Some(message_attributes))
262 .send()
263 .await
264 .map_err(|e| {
265 error!("AWS SNS publish failed: {}", e);
266 match e.into_service_error() {
267 aws_sdk_sns::operation::publish::PublishError::AuthorizationErrorException(_) => {
268 SmsError::Auth("AWS authorization failed".to_string())
269 }
270 aws_sdk_sns::operation::publish::PublishError::InvalidParameterException(e) => {
271 SmsError::Invalid(e.message().unwrap_or("Invalid parameter").to_string())
272 }
273 aws_sdk_sns::operation::publish::PublishError::InvalidParameterValueException(e) => {
274 SmsError::Invalid(e.message().unwrap_or("Invalid parameter value").to_string())
275 }
276 e => SmsError::Provider(format!("AWS SNS error: {}", e)),
277 }
278 })?;
279
280 let message_id = result.message_id().unwrap_or_default().to_string();
281
282 info!(
283 "SMS sent successfully via AWS SNS with MessageId: {}",
284 message_id
285 );
286
287 let raw_json = serde_json::json!({
288 "MessageId": message_id,
289 "Region": self.region,
290 "ResponseMetadata": {
291 "HTTPStatusCode": 200
292 }
293 });
294
295 Ok(SendResponse {
296 id: message_id,
297 provider: "aws-sns",
298 raw: raw_json,
299 })
300 }
301}
302
303#[async_trait]
304impl InboundWebhook for AwsSnsClient {
305 fn provider(&self) -> &'static str {
306 "aws-sns"
307 }
308
309 fn parse_inbound(&self, headers: &Headers, body: &[u8]) -> Result<InboundMessage, SmsError> {
310 debug!("Parsing AWS SNS webhook");
311
312 let payload_str = String::from_utf8(body.to_vec()).map_err(|e| {
313 error!("Invalid UTF-8 in AWS SNS webhook: {}", e);
314 SmsError::Provider(format!("Invalid UTF-8: {}", e))
315 })?;
316
317 if let Some(signature) = headers.iter().find_map(|(k, v)| {
318 if k.eq_ignore_ascii_case("x-amz-sns-message-type") {
319 Some(v.as_str())
320 } else {
321 None
322 }
323 }) {
324 debug!("SNS message type: {}", signature);
325 }
326
327 let notification: SnsDeliveryNotification =
328 serde_json::from_str(&payload_str).map_err(|e| {
329 error!("Failed to parse SNS notification: {}", e);
330 SmsError::Provider(format!("Invalid notification format: {}", e))
331 })?;
332
333 if notification.notification_type == "Notification" {
334 if let Ok(delivery_report) =
335 serde_json::from_str::<SmsDeliveryReport>(¬ification.message)
336 {
337 info!(
338 "Received SMS delivery report for message: {}",
339 delivery_report.message_id
340 );
341
342 let timestamp = time::OffsetDateTime::parse(
343 ¬ification.timestamp,
344 &time::format_description::well_known::Rfc3339,
345 )
346 .ok();
347
348 let raw_json = serde_json::to_value(¬ification)
349 .map_err(|e| SmsError::Provider(format!("JSON serialization error: {}", e)))?;
350
351 return Ok(InboundMessage {
352 id: Some(delivery_report.message_id),
353 from: "AWS-SNS".to_string(),
354 to: delivery_report.destination_phone_number,
355 text: format!("Delivery Status: {}", delivery_report.status),
356 timestamp,
357 provider: "aws-sns",
358 raw: raw_json,
359 });
360 }
361 }
362
363 if notification.notification_type == "SubscriptionConfirmation" {
364 warn!("Received SNS subscription confirmation, manual confirmation may be required");
365
366 let raw_json = serde_json::to_value(¬ification)
367 .map_err(|e| SmsError::Provider(format!("JSON serialization error: {}", e)))?;
368
369 let timestamp = time::OffsetDateTime::parse(
370 ¬ification.timestamp,
371 &time::format_description::well_known::Rfc3339,
372 )
373 .ok();
374
375 return Ok(InboundMessage {
376 id: Some(notification.message_id),
377 from: "AWS-SNS".to_string(),
378 to: "SYSTEM".to_string(),
379 text: "Subscription confirmation required".to_string(),
380 timestamp,
381 provider: "aws-sns",
382 raw: raw_json,
383 });
384 }
385
386 error!(
387 "Unknown SNS notification type: {}",
388 notification.notification_type
389 );
390 Err(SmsError::Provider(format!(
391 "Unsupported notification type: {}",
392 notification.notification_type
393 )))
394 }
395}
396
// Unit tests for client construction, environment-based configuration, and SNS webhook
// parsing. No test here calls `send`, so none performs a network request.
#[cfg(test)]
mod tests {
    use super::*;

    // `new` stores the region it was given.
    #[test]
    fn client_creation() {
        let client = AwsSnsClient::new("us-east-1", "test_key", "test_secret");
        assert_eq!(client.region, "us-east-1");
    }

    // Same as above with a non-default region.
    #[test]
    fn client_creation_different_region() {
        let client = AwsSnsClient::new("eu-west-1", "key", "secret");
        assert_eq!(client.region, "eu-west-1");
    }

    // Walks `from_env` through each missing-variable error in order. It then checks the
    // success path via AWS_REGION and the AWS_DEFAULT_REGION fallback. Every step depends on
    // the environment left by the previous one, so statement order matters.
    //
    // SAFETY: `set_var`/`remove_var` mutate process-global state (and are `unsafe` in the 2024
    // edition). They are only sound while no other thread reads or writes the environment.
    // NOTE(review): tests run in parallel by default — confirm no other test touches these
    // variables. Pre-existing values are not restored afterwards.
    #[test]
    fn from_env_scenarios() {
        unsafe {
            std::env::remove_var("AWS_REGION");
            std::env::remove_var("AWS_DEFAULT_REGION");
            std::env::remove_var("AWS_ACCESS_KEY_ID");
            std::env::remove_var("AWS_SECRET_ACCESS_KEY");
        }
        // Nothing set: the region is reported first.
        let err = AwsSnsClient::from_env().unwrap_err();
        assert!(err.to_string().contains("AWS_REGION"));

        unsafe { std::env::set_var("AWS_REGION", "us-east-1"); }
        // Region set: the access key id is reported next.
        let err = AwsSnsClient::from_env().unwrap_err();
        assert!(err.to_string().contains("AWS_ACCESS_KEY_ID"));

        unsafe { std::env::set_var("AWS_ACCESS_KEY_ID", "test-key"); }
        // Only the secret is still missing.
        let err = AwsSnsClient::from_env().unwrap_err();
        assert!(err.to_string().contains("AWS_SECRET_ACCESS_KEY"));

        unsafe { std::env::set_var("AWS_SECRET_ACCESS_KEY", "test-secret"); }
        // All required variables present: AWS_REGION is used.
        let client = AwsSnsClient::from_env().unwrap();
        assert_eq!(client.region, "us-east-1");

        unsafe {
            std::env::remove_var("AWS_REGION");
            std::env::set_var("AWS_DEFAULT_REGION", "ap-southeast-1");
        }
        // Without AWS_REGION, AWS_DEFAULT_REGION is the fallback.
        let client = AwsSnsClient::from_env().unwrap();
        assert_eq!(client.region, "ap-southeast-1");

        // Cleanup: leave the variables unset for any later code in this process.
        unsafe {
            std::env::remove_var("AWS_REGION");
            std::env::remove_var("AWS_DEFAULT_REGION");
            std::env::remove_var("AWS_ACCESS_KEY_ID");
            std::env::remove_var("AWS_SECRET_ACCESS_KEY");
        }
    }

    // The `InboundWebhook` provider name is the fixed string "aws-sns".
    #[test]
    fn provider_name() {
        let client = AwsSnsClient::new("us-east-1", "test_key", "test_secret");
        assert_eq!(client.provider(), "aws-sns");
    }

    // SNS `Notification` envelope whose `Message` is an escaped JSON SMS delivery report
    // (status SUCCESS, message id "msg-123", destination "+1234567890").
    fn delivery_report_json() -> String {
        r#"{
            "Type": "Notification",
            "MessageId": "test-message-id",
            "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic",
            "Message": "{\"notification\":{\"messageId\":\"msg-123\",\"timestamp\":\"2023-01-01T00:00:00.000Z\"},\"delivery\":{\"destination\":\"+1234567890\",\"priceInUSD\":0.00645,\"smsType\":\"Transactional\"},\"status\":\"SUCCESS\",\"messageId\":\"msg-123\",\"destinationPhoneNumber\":\"+1234567890\"}",
            "Timestamp": "2023-01-01T00:00:00.000Z",
            "SignatureVersion": "1",
            "Signature": "test-signature",
            "SigningCertURL": "https://sns.us-east-1.amazonaws.com/test.pem"
        }"#.to_string()
    }

    // A delivery report maps to the inner report's id and destination, embeds the status in
    // `text`, and parses the envelope timestamp.
    #[test]
    fn webhook_parsing_delivery_report() {
        let client = AwsSnsClient::new("us-east-1", "test_key", "test_secret");
        let json = delivery_report_json();
        let headers = vec![];
        let result = client.parse_inbound(&headers, json.as_bytes());

        assert!(result.is_ok());
        let message = result.unwrap();
        assert_eq!(message.id, Some("msg-123".to_string()));
        assert_eq!(message.to, "+1234567890");
        assert_eq!(message.provider, "aws-sns");
        assert!(message.text.contains("SUCCESS"));
        assert!(message.timestamp.is_some());
    }

    // Delivery reports use the fixed sender "AWS-SNS".
    #[test]
    fn webhook_delivery_report_from_field() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = delivery_report_json();
        let msg = client.parse_inbound(&vec![], json.as_bytes()).unwrap();
        assert_eq!(msg.from, "AWS-SNS");
    }

    // `raw` carries the outer SNS envelope (checked via its `TopicArn` key).
    #[test]
    fn webhook_delivery_report_raw_contains_notification() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = delivery_report_json();
        let msg = client.parse_inbound(&vec![], json.as_bytes()).unwrap();
        assert!(msg.raw.get("TopicArn").is_some());
    }

    // SNS `SubscriptionConfirmation` envelope with a plain-text `Message`.
    fn subscription_confirmation_json() -> String {
        r#"{
            "Type": "SubscriptionConfirmation",
            "MessageId": "subscription-message-id",
            "TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic",
            "Message": "You have chosen to subscribe to the topic...",
            "Timestamp": "2023-01-01T00:00:00.000Z",
            "SignatureVersion": "1",
            "Signature": "test-signature",
            "SigningCertURL": "https://sns.us-east-1.amazonaws.com/test.pem"
        }"#.to_string()
    }

    // A subscription confirmation maps to a "SYSTEM" message keyed by the envelope id.
    #[test]
    fn webhook_parsing_subscription_confirmation() {
        let client = AwsSnsClient::new("us-east-1", "test_key", "test_secret");
        let json = subscription_confirmation_json();
        let result = client.parse_inbound(&vec![], json.as_bytes());

        assert!(result.is_ok());
        let message = result.unwrap();
        assert_eq!(message.id, Some("subscription-message-id".to_string()));
        assert_eq!(message.text, "Subscription confirmation required");
        assert_eq!(message.to, "SYSTEM");
        assert_eq!(message.provider, "aws-sns");
    }

    // Any other `Type` is rejected with an "Unsupported notification type" error.
    #[test]
    fn webhook_parsing_unknown_type_errors() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = r#"{
            "Type": "SomethingNew",
            "MessageId": "id",
            "TopicArn": "arn",
            "Message": "...",
            "Timestamp": "2023-01-01T00:00:00.000Z",
            "SignatureVersion": "1",
            "Signature": "sig",
            "SigningCertURL": "https://example.com/cert.pem"
        }"#;
        let result = client.parse_inbound(&vec![], json.as_bytes());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unsupported notification type"));
    }

    // A body that is not JSON is an error.
    #[test]
    fn webhook_parsing_invalid_json() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let result = client.parse_inbound(&vec![], b"not json");
        assert!(result.is_err());
    }

    // A body that is not valid UTF-8 is an error mentioning UTF-8.
    #[test]
    fn webhook_parsing_invalid_utf8() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let result = client.parse_inbound(&vec![], &[0xFF, 0xFE]);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("UTF-8"));
    }

    // The `x-amz-sns-message-type` header is accepted and does not change the outcome.
    #[test]
    fn webhook_with_message_type_header() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = subscription_confirmation_json();
        let headers = vec![(
            "x-amz-sns-message-type".to_string(),
            "SubscriptionConfirmation".to_string(),
        )];
        let result = client.parse_inbound(&headers, json.as_bytes());
        assert!(result.is_ok());
    }

    // A FAILURE report (no `priceInUSD` in `delivery`) still parses; the status lands in `text`.
    #[test]
    fn webhook_delivery_report_failure_status() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = r#"{
            "Type": "Notification",
            "MessageId": "test-id",
            "TopicArn": "arn:aws:sns:us-east-1:123:topic",
            "Message": "{\"notification\":{\"messageId\":\"msg-fail\",\"timestamp\":\"2023-06-15T10:00:00.000Z\"},\"delivery\":{\"destination\":\"+19875551234\",\"smsType\":\"Transactional\"},\"status\":\"FAILURE\",\"messageId\":\"msg-fail\",\"destinationPhoneNumber\":\"+19875551234\"}",
            "Timestamp": "2023-06-15T10:00:00.000Z",
            "SignatureVersion": "1",
            "Signature": "sig",
            "SigningCertURL": "https://sns.us-east-1.amazonaws.com/cert.pem"
        }"#;
        let msg = client.parse_inbound(&vec![], json.as_bytes()).unwrap();
        assert!(msg.text.contains("FAILURE"));
        assert_eq!(msg.id, Some("msg-fail".into()));
    }

    // A `Notification` whose `Message` is not a delivery report is rejected with the same
    // "Unsupported notification type" error as an unknown `Type`.
    #[test]
    fn webhook_notification_with_non_delivery_message() {
        let client = AwsSnsClient::new("us-east-1", "k", "s");
        let json = r#"{
            "Type": "Notification",
            "MessageId": "notif-id",
            "TopicArn": "arn:aws:sns:us-east-1:123:topic",
            "Message": "This is a plain text notification, not a delivery report",
            "Timestamp": "2023-01-01T00:00:00.000Z",
            "SignatureVersion": "1",
            "Signature": "sig",
            "SigningCertURL": "https://sns.us-east-1.amazonaws.com/cert.pem"
        }"#;
        let result = client.parse_inbound(&vec![], json.as_bytes());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unsupported notification type"));
    }

    // The envelope struct survives a serialize/deserialize round trip with its serde renames.
    #[test]
    fn sns_notification_serde_roundtrip() {
        let json = delivery_report_json();
        let notif: SnsDeliveryNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(notif.notification_type, "Notification");
        assert_eq!(notif.message_id, "test-message-id");

        let reserialized = serde_json::to_string(&notif).unwrap();
        let notif2: SnsDeliveryNotification = serde_json::from_str(&reserialized).unwrap();
        assert_eq!(notif2.message_id, notif.message_id);
    }

    // A delivery report with every optional `delivery` field present deserializes them all.
    #[test]
    fn delivery_report_serde() {
        let inner = r#"{"notification":{"messageId":"m1","timestamp":"2023-01-01T00:00:00Z"},"delivery":{"destination":"+1","priceInUSD":0.005,"smsType":"Transactional","dwellTimeMs":100,"dwellTimeMsUntilDeviceAck":200},"status":"SUCCESS","messageId":"m1","destinationPhoneNumber":"+1"}"#;
        let report: SmsDeliveryReport = serde_json::from_str(inner).unwrap();
        assert_eq!(report.status, "SUCCESS");
        assert_eq!(report.message_id, "m1");
        assert_eq!(report.delivery.price_in_usd, Some(0.005));
        assert_eq!(report.delivery.dwell_time_ms, Some(100));
    }

    // An owned request borrows into a `SendRequest`. Its alphanumeric `from` (no leading `+`)
    // is the shape `send` would forward as a sender ID.
    #[test]
    fn owned_request_can_be_borrowed_for_send() {
        let owned = sms_core::OwnedSendRequest::new("+14155551234", "MySenderID", "Hello SNS!");
        let borrowed = owned.as_ref();
        assert_eq!(borrowed.to, "+14155551234");
        assert_eq!(borrowed.from, "MySenderID");
        assert!(!borrowed.from.starts_with('+'));
    }
}