Skip to main content

rustant_core/channels/
sms.rs

//! SMS channel via the Twilio API.
//!
//! Sends and receives SMS through the Twilio REST API using `reqwest`.
//! All HTTP traffic goes through the [`SmsHttpClient`] trait, so tests can
//! substitute mock implementations for the real client.
5
6use super::{
7    Channel, ChannelCapabilities, ChannelMessage, ChannelStatus, ChannelType, ChannelUser,
8    MessageId, StreamingMode,
9};
10use crate::error::{ChannelError, RustantError};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
14/// Configuration for an SMS channel.
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct SmsConfig {
17    pub enabled: bool,
18    pub account_sid: String,
19    pub auth_token: String,
20    pub from_number: String,
21    pub polling_interval_ms: u64,
22}
23
24/// Trait for SMS API interactions.
25#[async_trait]
26pub trait SmsHttpClient: Send + Sync {
27    async fn send_sms(&self, to: &str, body: &str) -> Result<String, String>;
28    async fn get_messages(&self) -> Result<Vec<SmsIncoming>, String>;
29}
30
/// An inbound SMS as reported by an [`SmsHttpClient`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmsIncoming {
    /// Provider-assigned message identifier (the Twilio message SID for `RealSmsHttp`).
    pub sid: String,
    /// Sender's phone number; `SmsChannel` uses it as the message's channel id.
    pub from: String,
    /// Message text (`RealSmsHttp` leaves it empty when the provider sends no body).
    pub body: String,
}
38
39/// SMS channel.
40pub struct SmsChannel {
41    config: SmsConfig,
42    status: ChannelStatus,
43    http_client: Box<dyn SmsHttpClient>,
44    name: String,
45}
46
47impl SmsChannel {
48    pub fn new(config: SmsConfig, http_client: Box<dyn SmsHttpClient>) -> Self {
49        Self {
50            config,
51            status: ChannelStatus::Disconnected,
52            http_client,
53            name: "sms".to_string(),
54        }
55    }
56
57    pub fn with_name(mut self, name: impl Into<String>) -> Self {
58        self.name = name.into();
59        self
60    }
61}
62
63#[async_trait]
64impl Channel for SmsChannel {
65    fn name(&self) -> &str {
66        &self.name
67    }
68
69    fn channel_type(&self) -> ChannelType {
70        ChannelType::Sms
71    }
72
73    async fn connect(&mut self) -> Result<(), RustantError> {
74        if self.config.account_sid.is_empty() || self.config.auth_token.is_empty() {
75            return Err(RustantError::Channel(ChannelError::AuthFailed {
76                name: self.name.clone(),
77            }));
78        }
79        self.status = ChannelStatus::Connected;
80        Ok(())
81    }
82
83    async fn disconnect(&mut self) -> Result<(), RustantError> {
84        self.status = ChannelStatus::Disconnected;
85        Ok(())
86    }
87
88    async fn send_message(&self, msg: ChannelMessage) -> Result<MessageId, RustantError> {
89        if self.status != ChannelStatus::Connected {
90            return Err(RustantError::Channel(ChannelError::NotConnected {
91                name: self.name.clone(),
92            }));
93        }
94        let text = msg.content.as_text().unwrap_or("");
95        self.http_client
96            .send_sms(&msg.channel_id, text)
97            .await
98            .map(MessageId::new)
99            .map_err(|e| {
100                RustantError::Channel(ChannelError::SendFailed {
101                    name: self.name.clone(),
102                    message: e,
103                })
104            })
105    }
106
107    async fn receive_messages(&self) -> Result<Vec<ChannelMessage>, RustantError> {
108        let incoming = self.http_client.get_messages().await.map_err(|e| {
109            RustantError::Channel(ChannelError::ConnectionFailed {
110                name: self.name.clone(),
111                message: e,
112            })
113        })?;
114
115        let messages = incoming
116            .into_iter()
117            .map(|m| {
118                let sender = ChannelUser::new(&m.from, ChannelType::Sms);
119                ChannelMessage::text(ChannelType::Sms, &m.from, sender, &m.body)
120            })
121            .collect();
122
123        Ok(messages)
124    }
125
126    fn status(&self) -> ChannelStatus {
127        self.status
128    }
129
130    fn capabilities(&self) -> ChannelCapabilities {
131        ChannelCapabilities {
132            supports_threads: false,
133            supports_reactions: false,
134            supports_files: false,
135            supports_voice: false,
136            supports_video: false,
137            max_message_length: Some(1600),
138            supports_editing: false,
139            supports_deletion: false,
140        }
141    }
142
143    fn streaming_mode(&self) -> StreamingMode {
144        StreamingMode::Polling {
145            interval_ms: self.config.polling_interval_ms.max(1000),
146        }
147    }
148}
149
150/// Real Twilio SMS HTTP client using reqwest.
151pub struct RealSmsHttp {
152    client: reqwest::Client,
153    account_sid: String,
154    auth_token: String,
155    from_number: String,
156}
157
158impl RealSmsHttp {
159    pub fn new(account_sid: String, auth_token: String, from_number: String) -> Self {
160        Self {
161            client: reqwest::Client::new(),
162            account_sid,
163            auth_token,
164            from_number,
165        }
166    }
167}
168
169#[async_trait]
170impl SmsHttpClient for RealSmsHttp {
171    async fn send_sms(&self, to: &str, body: &str) -> Result<String, String> {
172        let url = format!(
173            "https://api.twilio.com/2010-04-01/Accounts/{}/Messages.json",
174            self.account_sid
175        );
176        let auth = base64::Engine::encode(
177            &base64::engine::general_purpose::STANDARD,
178            format!("{}:{}", self.account_sid, self.auth_token),
179        );
180        let resp = self
181            .client
182            .post(&url)
183            .header("Authorization", format!("Basic {}", auth))
184            .form(&[("To", to), ("From", &self.from_number), ("Body", body)])
185            .send()
186            .await
187            .map_err(|e| format!("HTTP error: {e}"))?;
188
189        let status = resp.status();
190        let body: serde_json::Value = resp
191            .json()
192            .await
193            .map_err(|e| format!("JSON parse error: {e}"))?;
194
195        if !status.is_success() {
196            let msg = body["message"].as_str().unwrap_or("unknown error");
197            return Err(format!("Twilio API error ({}): {}", status, msg));
198        }
199
200        let sid = body["sid"].as_str().unwrap_or("unknown").to_string();
201        Ok(sid)
202    }
203
204    async fn get_messages(&self) -> Result<Vec<SmsIncoming>, String> {
205        let url = format!(
206            "https://api.twilio.com/2010-04-01/Accounts/{}/Messages.json?PageSize=20",
207            self.account_sid
208        );
209        let auth = base64::Engine::encode(
210            &base64::engine::general_purpose::STANDARD,
211            format!("{}:{}", self.account_sid, self.auth_token),
212        );
213        let resp = self
214            .client
215            .get(&url)
216            .header("Authorization", format!("Basic {}", auth))
217            .send()
218            .await
219            .map_err(|e| format!("HTTP error: {e}"))?;
220
221        let body: serde_json::Value = resp
222            .json()
223            .await
224            .map_err(|e| format!("JSON parse error: {e}"))?;
225
226        let messages = body["messages"]
227            .as_array()
228            .unwrap_or(&Vec::new())
229            .iter()
230            .filter(|m| m["direction"].as_str() == Some("inbound"))
231            .filter_map(|m| {
232                Some(SmsIncoming {
233                    sid: m["sid"].as_str()?.to_string(),
234                    from: m["from"].as_str()?.to_string(),
235                    body: m["body"].as_str().unwrap_or("").to_string(),
236                })
237            })
238            .collect();
239
240        Ok(messages)
241    }
242}
243
244/// Create an SMS channel with a real Twilio HTTP client.
245pub fn create_sms_channel(config: SmsConfig) -> SmsChannel {
246    let http = RealSmsHttp::new(
247        config.account_sid.clone(),
248        config.auth_token.clone(),
249        config.from_number.clone(),
250    );
251    SmsChannel::new(config, Box::new(http))
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock client: sends always succeed with a fixed SID; the inbox is empty.
    struct MockSmsHttp;

    #[async_trait]
    impl SmsHttpClient for MockSmsHttp {
        async fn send_sms(&self, _to: &str, _body: &str) -> Result<String, String> {
            Ok("SM123".into())
        }
        async fn get_messages(&self) -> Result<Vec<SmsIncoming>, String> {
            Ok(vec![])
        }
    }

    /// Mock client whose inbox returns the same fixed messages on every poll.
    struct InboxMock(Vec<SmsIncoming>);

    #[async_trait]
    impl SmsHttpClient for InboxMock {
        async fn send_sms(&self, _to: &str, _body: &str) -> Result<String, String> {
            Ok("SM456".into())
        }
        async fn get_messages(&self) -> Result<Vec<SmsIncoming>, String> {
            Ok(self.0.clone())
        }
    }

    /// A config with non-empty credentials, so `connect` succeeds.
    fn credentials() -> SmsConfig {
        SmsConfig {
            enabled: true,
            account_sid: "AC123".into(),
            auth_token: "token".into(),
            from_number: "+15550000000".into(),
            ..SmsConfig::default()
        }
    }

    #[test]
    fn test_sms_channel_creation() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        assert_eq!(ch.name(), "sms");
        assert_eq!(ch.channel_type(), ChannelType::Sms);
    }

    #[test]
    fn test_sms_with_name() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp)).with_name("alerts");
        assert_eq!(ch.name(), "alerts");
    }

    #[test]
    fn test_sms_capabilities() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        let caps = ch.capabilities();
        assert!(!caps.supports_threads);
        assert!(!caps.supports_files);
        assert_eq!(caps.max_message_length, Some(1600));
    }

    #[test]
    fn test_sms_streaming_mode() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        assert_eq!(
            ch.streaming_mode(),
            StreamingMode::Polling { interval_ms: 1000 }
        );
    }

    #[test]
    fn test_sms_streaming_mode_keeps_longer_interval() {
        let cfg = SmsConfig {
            polling_interval_ms: 5000,
            ..SmsConfig::default()
        };
        let ch = SmsChannel::new(cfg, Box::new(MockSmsHttp));
        assert_eq!(
            ch.streaming_mode(),
            StreamingMode::Polling { interval_ms: 5000 }
        );
    }

    #[test]
    fn test_sms_status_disconnected() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        assert_eq!(ch.status(), ChannelStatus::Disconnected);
    }

    #[tokio::test]
    async fn test_sms_send_without_connect() {
        let ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        let sender = ChannelUser::new("bot", ChannelType::Sms);
        let msg = ChannelMessage::text(ChannelType::Sms, "+1234", sender, "hi");
        assert!(ch.send_message(msg).await.is_err());
    }

    #[tokio::test]
    async fn test_sms_connect_requires_credentials() {
        let mut ch = SmsChannel::new(SmsConfig::default(), Box::new(MockSmsHttp));
        let err = ch.connect().await.unwrap_err();
        assert!(matches!(
            err,
            RustantError::Channel(ChannelError::AuthFailed { .. })
        ));
        assert_eq!(ch.status(), ChannelStatus::Disconnected);
    }

    #[tokio::test]
    async fn test_sms_connect_send_disconnect() {
        let mut ch = SmsChannel::new(credentials(), Box::new(MockSmsHttp));
        assert!(ch.connect().await.is_ok());
        assert_eq!(ch.status(), ChannelStatus::Connected);

        let sender = ChannelUser::new("bot", ChannelType::Sms);
        let msg = ChannelMessage::text(ChannelType::Sms, "+1234", sender, "hi");
        assert!(ch.send_message(msg).await.is_ok());

        assert!(ch.disconnect().await.is_ok());
        assert_eq!(ch.status(), ChannelStatus::Disconnected);
    }

    #[tokio::test]
    async fn test_sms_receive_maps_incoming() {
        let inbox = vec![SmsIncoming {
            sid: "SM1".into(),
            from: "+15551234567".into(),
            body: "hello".into(),
        }];
        let mut ch = SmsChannel::new(credentials(), Box::new(InboxMock(inbox)));
        assert!(ch.connect().await.is_ok());

        let Ok(msgs) = ch.receive_messages().await else {
            panic!("receive_messages failed");
        };
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content.as_text(), Some("hello"));
    }
}