Skip to main content

rustant_core/channels/
matrix.rs

1//! Matrix Client-Server API channel implementation.
2//!
3//! Connects to a Matrix homeserver via the Client-Server API using reqwest.
//! All HTTP traffic goes through the [`MatrixHttpClient`] trait, so tests can swap in a mock implementation.
5
6use super::{
7    Channel, ChannelCapabilities, ChannelMessage, ChannelStatus, ChannelType, ChannelUser,
8    MessageId, StreamingMode,
9};
10use crate::error::{ChannelError, RustantError};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
14/// Configuration for a Matrix channel.
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct MatrixConfig {
17    pub homeserver_url: String,
18    pub access_token: String,
19    pub user_id: String,
20    pub room_ids: Vec<String>,
21}
22
/// Trait for Matrix API interactions.
///
/// Abstracts the HTTP layer so that [`MatrixChannel`] can be driven by
/// [`RealMatrixHttp`] in production or by a mock in tests. Failures are
/// reported as human-readable error strings.
#[async_trait]
pub trait MatrixHttpClient: Send + Sync {
    /// Sends `text` as a plain-text message to `room_id` and returns the new event ID.
    async fn send_message(&self, room_id: &str, text: &str) -> Result<String, String>;
    /// Fetches room message events; `since` is a sync token from an earlier sync, or `None`.
    async fn sync(&self, since: Option<&str>) -> Result<Vec<MatrixEvent>, String>;
    /// Validates the configured credentials. The real implementation returns the
    /// user ID reported by the homeserver's `/account/whoami` endpoint.
    async fn login(&self) -> Result<String, String>;
}
30
/// A Matrix room event.
///
/// [`RealMatrixHttp`] only produces these for `m.room.message` events that
/// carry a textual `content.body`.
#[derive(Debug, Clone)]
pub struct MatrixEvent {
    /// Event ID; empty if the server response omitted it.
    pub event_id: String,
    /// Room the event was received in.
    pub room_id: String,
    /// User ID of the sender; empty if the server response omitted it.
    pub sender: String,
    /// The message's `content.body` text.
    pub body: String,
}
39
/// Matrix channel.
///
/// Implements [`Channel`] on top of a pluggable [`MatrixHttpClient`];
/// connection state is tracked locally in `status`.
pub struct MatrixChannel {
    /// Homeserver, credentials, user ID and room allow-list.
    config: MatrixConfig,
    /// Set to `Connected` by a successful `connect`, reset by `disconnect`.
    status: ChannelStatus,
    /// Transport used for all API calls (real HTTP or a test mock).
    http_client: Box<dyn MatrixHttpClient>,
    /// Name reported by `Channel::name`; `"matrix"` unless overridden via `with_name`.
    name: String,
    /// Sync token passed to `MatrixHttpClient::sync`.
    // NOTE(review): nothing in this file assigns this field after construction,
    // so it is always `None` when handed to `sync`.
    next_batch: Option<String>,
}
48
49impl MatrixChannel {
50    pub fn new(config: MatrixConfig, http_client: Box<dyn MatrixHttpClient>) -> Self {
51        Self {
52            config,
53            status: ChannelStatus::Disconnected,
54            http_client,
55            name: "matrix".to_string(),
56            next_batch: None,
57        }
58    }
59
60    pub fn with_name(mut self, name: impl Into<String>) -> Self {
61        self.name = name.into();
62        self
63    }
64}
65
66#[async_trait]
67impl Channel for MatrixChannel {
68    fn name(&self) -> &str {
69        &self.name
70    }
71
72    fn channel_type(&self) -> ChannelType {
73        ChannelType::Matrix
74    }
75
76    async fn connect(&mut self) -> Result<(), RustantError> {
77        if self.config.homeserver_url.is_empty() || self.config.access_token.is_empty() {
78            return Err(RustantError::Channel(ChannelError::AuthFailed {
79                name: self.name.clone(),
80            }));
81        }
82        self.http_client.login().await.map_err(|e| {
83            RustantError::Channel(ChannelError::ConnectionFailed {
84                name: self.name.clone(),
85                message: e,
86            })
87        })?;
88        self.status = ChannelStatus::Connected;
89        Ok(())
90    }
91
92    async fn disconnect(&mut self) -> Result<(), RustantError> {
93        self.status = ChannelStatus::Disconnected;
94        Ok(())
95    }
96
97    async fn send_message(&self, msg: ChannelMessage) -> Result<MessageId, RustantError> {
98        let text = msg.content.as_text().unwrap_or("");
99        self.http_client
100            .send_message(&msg.channel_id, text)
101            .await
102            .map(MessageId::new)
103            .map_err(|e| {
104                RustantError::Channel(ChannelError::SendFailed {
105                    name: self.name.clone(),
106                    message: e,
107                })
108            })
109    }
110
111    async fn receive_messages(&self) -> Result<Vec<ChannelMessage>, RustantError> {
112        let events = self
113            .http_client
114            .sync(self.next_batch.as_deref())
115            .await
116            .map_err(|e| {
117                RustantError::Channel(ChannelError::ConnectionFailed {
118                    name: self.name.clone(),
119                    message: e,
120                })
121            })?;
122
123        let messages = events
124            .into_iter()
125            .filter(|e| {
126                self.config.room_ids.is_empty() || self.config.room_ids.contains(&e.room_id)
127            })
128            .map(|e| {
129                let sender = ChannelUser::new(&e.sender, ChannelType::Matrix);
130                ChannelMessage::text(ChannelType::Matrix, &e.room_id, sender, &e.body)
131            })
132            .collect();
133
134        Ok(messages)
135    }
136
137    fn status(&self) -> ChannelStatus {
138        self.status
139    }
140
141    fn capabilities(&self) -> ChannelCapabilities {
142        ChannelCapabilities {
143            supports_threads: true,
144            supports_reactions: true,
145            supports_files: true,
146            supports_voice: false,
147            supports_video: false,
148            max_message_length: None,
149            supports_editing: true,
150            supports_deletion: true,
151        }
152    }
153
154    fn streaming_mode(&self) -> StreamingMode {
155        StreamingMode::LongPolling
156    }
157}
158
159/// Real Matrix Client-Server API HTTP client using reqwest.
160pub struct RealMatrixHttp {
161    client: reqwest::Client,
162    homeserver_url: String,
163    access_token: String,
164    txn_counter: std::sync::atomic::AtomicU64,
165}
166
167impl RealMatrixHttp {
168    pub fn new(homeserver_url: String, access_token: String) -> Self {
169        Self {
170            client: reqwest::Client::new(),
171            homeserver_url: homeserver_url.trim_end_matches('/').to_string(),
172            access_token,
173            txn_counter: std::sync::atomic::AtomicU64::new(0),
174        }
175    }
176}
177
178#[async_trait]
179impl MatrixHttpClient for RealMatrixHttp {
180    async fn send_message(&self, room_id: &str, text: &str) -> Result<String, String> {
181        let txn_id = self
182            .txn_counter
183            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
184        let url = format!(
185            "{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}",
186            self.homeserver_url, room_id, txn_id
187        );
188        let resp = self
189            .client
190            .put(&url)
191            .header("Authorization", format!("Bearer {}", self.access_token))
192            .json(&serde_json::json!({
193                "msgtype": "m.text",
194                "body": text,
195            }))
196            .send()
197            .await
198            .map_err(|e| format!("HTTP error: {e}"))?;
199
200        let status = resp.status();
201        let body: serde_json::Value = resp
202            .json()
203            .await
204            .map_err(|e| format!("JSON parse error: {e}"))?;
205
206        if !status.is_success() {
207            let err = body["error"].as_str().unwrap_or("unknown error");
208            return Err(format!("Matrix API error ({}): {}", status, err));
209        }
210
211        let event_id = body["event_id"].as_str().unwrap_or("").to_string();
212        Ok(event_id)
213    }
214
215    async fn sync(&self, since: Option<&str>) -> Result<Vec<MatrixEvent>, String> {
216        let mut url = format!(
217            "{}/_matrix/client/v3/sync?timeout=30000",
218            self.homeserver_url
219        );
220        if let Some(since) = since {
221            url.push_str(&format!("&since={}", since));
222        }
223        let resp = self
224            .client
225            .get(&url)
226            .header("Authorization", format!("Bearer {}", self.access_token))
227            .send()
228            .await
229            .map_err(|e| format!("HTTP error: {e}"))?;
230
231        let body: serde_json::Value = resp
232            .json()
233            .await
234            .map_err(|e| format!("JSON parse error: {e}"))?;
235
236        let mut events = Vec::new();
237        if let Some(rooms) = body["rooms"]["join"].as_object() {
238            for (room_id, room_data) in rooms {
239                if let Some(timeline) = room_data["timeline"]["events"].as_array() {
240                    for event in timeline {
241                        if event["type"].as_str() == Some("m.room.message")
242                            && let Some(event_body) = event["content"]["body"].as_str()
243                        {
244                            events.push(MatrixEvent {
245                                event_id: event["event_id"].as_str().unwrap_or("").to_string(),
246                                room_id: room_id.clone(),
247                                sender: event["sender"].as_str().unwrap_or("").to_string(),
248                                body: event_body.to_string(),
249                            });
250                        }
251                    }
252                }
253            }
254        }
255
256        Ok(events)
257    }
258
259    async fn login(&self) -> Result<String, String> {
260        // When using an access token, verify it with whoami
261        let url = format!("{}/_matrix/client/v3/account/whoami", self.homeserver_url);
262        let resp = self
263            .client
264            .get(&url)
265            .header("Authorization", format!("Bearer {}", self.access_token))
266            .send()
267            .await
268            .map_err(|e| format!("HTTP error: {e}"))?;
269
270        let status = resp.status();
271        let body: serde_json::Value = resp
272            .json()
273            .await
274            .map_err(|e| format!("JSON parse error: {e}"))?;
275
276        if !status.is_success() {
277            let err = body["error"].as_str().unwrap_or("unauthorized");
278            return Err(format!("Matrix auth failed ({}): {}", status, err));
279        }
280
281        let user_id = body["user_id"].as_str().unwrap_or("").to_string();
282        Ok(user_id)
283    }
284}
285
286/// Create a Matrix channel with a real HTTP client.
287pub fn create_matrix_channel(config: MatrixConfig) -> MatrixChannel {
288    let http = RealMatrixHttp::new(config.homeserver_url.clone(), config.access_token.clone());
289    MatrixChannel::new(config, Box::new(http))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Canned client: every send yields `$event1`, every sync yields one
    /// message from `@alice:example.com` in `!room1:example.com`.
    struct MockMatrixHttp;

    #[async_trait]
    impl MatrixHttpClient for MockMatrixHttp {
        async fn send_message(&self, _room_id: &str, _text: &str) -> Result<String, String> {
            Ok("$event1".to_string())
        }
        async fn sync(&self, _since: Option<&str>) -> Result<Vec<MatrixEvent>, String> {
            Ok(vec![MatrixEvent {
                event_id: "$ev1".into(),
                room_id: "!room1:example.com".into(),
                sender: "@alice:example.com".into(),
                body: "hello matrix".into(),
            }])
        }
        async fn login(&self) -> Result<String, String> {
            Ok("token123".to_string())
        }
    }

    /// Minimal config that passes `connect`'s credential check.
    fn test_config() -> MatrixConfig {
        MatrixConfig {
            homeserver_url: "https://matrix.example.com".into(),
            access_token: "token".into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_matrix_connect() {
        let config = MatrixConfig {
            homeserver_url: "https://matrix.example.com".into(),
            access_token: "token".into(),
            user_id: "@bot:example.com".into(),
            room_ids: vec![],
        };
        let mut ch = MatrixChannel::new(config, Box::new(MockMatrixHttp));
        ch.connect().await.unwrap();
        assert!(ch.is_connected());
    }

    #[tokio::test]
    async fn test_matrix_connect_requires_credentials() {
        let mut ch = MatrixChannel::new(MatrixConfig::default(), Box::new(MockMatrixHttp));
        assert!(ch.connect().await.is_err());
        assert!(!ch.is_connected());
    }

    #[tokio::test]
    async fn test_matrix_send() {
        let mut ch = MatrixChannel::new(test_config(), Box::new(MockMatrixHttp));
        ch.connect().await.unwrap();

        let sender = ChannelUser::new("@bot:ex.com", ChannelType::Matrix);
        let msg = ChannelMessage::text(ChannelType::Matrix, "!room1:ex.com", sender, "hi");
        let id = ch.send_message(msg).await.unwrap();
        assert_eq!(id.0, "$event1");
    }

    #[tokio::test]
    async fn test_matrix_receive() {
        let mut ch = MatrixChannel::new(test_config(), Box::new(MockMatrixHttp));
        ch.connect().await.unwrap();

        let msgs = ch.receive_messages().await.unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content.as_text(), Some("hello matrix"));
    }

    #[tokio::test]
    async fn test_matrix_receive_keeps_listed_room() {
        let config = MatrixConfig {
            room_ids: vec!["!room1:example.com".into()],
            ..test_config()
        };
        let mut ch = MatrixChannel::new(config, Box::new(MockMatrixHttp));
        ch.connect().await.unwrap();

        assert_eq!(ch.receive_messages().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_matrix_receive_drops_unlisted_room() {
        let config = MatrixConfig {
            room_ids: vec!["!other:example.com".into()],
            ..test_config()
        };
        let mut ch = MatrixChannel::new(config, Box::new(MockMatrixHttp));
        ch.connect().await.unwrap();

        assert!(ch.receive_messages().await.unwrap().is_empty());
    }

    #[test]
    fn test_matrix_capabilities() {
        let ch = MatrixChannel::new(MatrixConfig::default(), Box::new(MockMatrixHttp));
        let caps = ch.capabilities();
        assert!(caps.supports_threads);
        assert!(caps.supports_reactions);
        assert!(caps.supports_files);
        assert!(caps.supports_editing);
        assert!(caps.supports_deletion);
        assert!(caps.max_message_length.is_none());
    }

    #[test]
    fn test_matrix_streaming_mode() {
        let ch = MatrixChannel::new(MatrixConfig::default(), Box::new(MockMatrixHttp));
        assert_eq!(ch.streaming_mode(), StreamingMode::LongPolling);
    }
}