1use super::{
7 Channel, ChannelCapabilities, ChannelMessage, ChannelStatus, ChannelType, ChannelUser,
8 MessageId, StreamingMode,
9};
10use crate::error::{ChannelError, RustantError};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct MatrixConfig {
17 pub homeserver_url: String,
18 pub access_token: String,
19 pub user_id: String,
20 pub room_ids: Vec<String>,
21}
22
23#[async_trait]
25pub trait MatrixHttpClient: Send + Sync {
26 async fn send_message(&self, room_id: &str, text: &str) -> Result<String, String>;
27 async fn sync(&self, since: Option<&str>) -> Result<Vec<MatrixEvent>, String>;
28 async fn login(&self) -> Result<String, String>;
29}
30
/// One `m.room.message` event pulled out of a sync response.
#[derive(Clone, Debug)]
pub struct MatrixEvent {
    /// Server-assigned identifier of the event.
    pub event_id: String,
    /// Room in which the event was sent.
    pub room_id: String,
    /// Matrix user ID of the author.
    pub sender: String,
    /// Plain-text body (`content.body`) of the message.
    pub body: String,
}
39
40pub struct MatrixChannel {
42 config: MatrixConfig,
43 status: ChannelStatus,
44 http_client: Box<dyn MatrixHttpClient>,
45 name: String,
46 next_batch: Option<String>,
47}
48
49impl MatrixChannel {
50 pub fn new(config: MatrixConfig, http_client: Box<dyn MatrixHttpClient>) -> Self {
51 Self {
52 config,
53 status: ChannelStatus::Disconnected,
54 http_client,
55 name: "matrix".to_string(),
56 next_batch: None,
57 }
58 }
59
60 pub fn with_name(mut self, name: impl Into<String>) -> Self {
61 self.name = name.into();
62 self
63 }
64}
65
66#[async_trait]
67impl Channel for MatrixChannel {
68 fn name(&self) -> &str {
69 &self.name
70 }
71
72 fn channel_type(&self) -> ChannelType {
73 ChannelType::Matrix
74 }
75
76 async fn connect(&mut self) -> Result<(), RustantError> {
77 if self.config.homeserver_url.is_empty() || self.config.access_token.is_empty() {
78 return Err(RustantError::Channel(ChannelError::AuthFailed {
79 name: self.name.clone(),
80 }));
81 }
82 self.http_client.login().await.map_err(|e| {
83 RustantError::Channel(ChannelError::ConnectionFailed {
84 name: self.name.clone(),
85 message: e,
86 })
87 })?;
88 self.status = ChannelStatus::Connected;
89 Ok(())
90 }
91
92 async fn disconnect(&mut self) -> Result<(), RustantError> {
93 self.status = ChannelStatus::Disconnected;
94 Ok(())
95 }
96
97 async fn send_message(&self, msg: ChannelMessage) -> Result<MessageId, RustantError> {
98 let text = msg.content.as_text().unwrap_or("");
99 self.http_client
100 .send_message(&msg.channel_id, text)
101 .await
102 .map(MessageId::new)
103 .map_err(|e| {
104 RustantError::Channel(ChannelError::SendFailed {
105 name: self.name.clone(),
106 message: e,
107 })
108 })
109 }
110
111 async fn receive_messages(&self) -> Result<Vec<ChannelMessage>, RustantError> {
112 let events = self
113 .http_client
114 .sync(self.next_batch.as_deref())
115 .await
116 .map_err(|e| {
117 RustantError::Channel(ChannelError::ConnectionFailed {
118 name: self.name.clone(),
119 message: e,
120 })
121 })?;
122
123 let messages = events
124 .into_iter()
125 .filter(|e| {
126 self.config.room_ids.is_empty() || self.config.room_ids.contains(&e.room_id)
127 })
128 .map(|e| {
129 let sender = ChannelUser::new(&e.sender, ChannelType::Matrix);
130 ChannelMessage::text(ChannelType::Matrix, &e.room_id, sender, &e.body)
131 })
132 .collect();
133
134 Ok(messages)
135 }
136
137 fn status(&self) -> ChannelStatus {
138 self.status
139 }
140
141 fn capabilities(&self) -> ChannelCapabilities {
142 ChannelCapabilities {
143 supports_threads: true,
144 supports_reactions: true,
145 supports_files: true,
146 supports_voice: false,
147 supports_video: false,
148 max_message_length: None,
149 supports_editing: true,
150 supports_deletion: true,
151 }
152 }
153
154 fn streaming_mode(&self) -> StreamingMode {
155 StreamingMode::LongPolling
156 }
157}
158
159pub struct RealMatrixHttp {
161 client: reqwest::Client,
162 homeserver_url: String,
163 access_token: String,
164 txn_counter: std::sync::atomic::AtomicU64,
165}
166
167impl RealMatrixHttp {
168 pub fn new(homeserver_url: String, access_token: String) -> Self {
169 Self {
170 client: reqwest::Client::new(),
171 homeserver_url: homeserver_url.trim_end_matches('/').to_string(),
172 access_token,
173 txn_counter: std::sync::atomic::AtomicU64::new(0),
174 }
175 }
176}
177
178#[async_trait]
179impl MatrixHttpClient for RealMatrixHttp {
180 async fn send_message(&self, room_id: &str, text: &str) -> Result<String, String> {
181 let txn_id = self
182 .txn_counter
183 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
184 let url = format!(
185 "{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}",
186 self.homeserver_url, room_id, txn_id
187 );
188 let resp = self
189 .client
190 .put(&url)
191 .header("Authorization", format!("Bearer {}", self.access_token))
192 .json(&serde_json::json!({
193 "msgtype": "m.text",
194 "body": text,
195 }))
196 .send()
197 .await
198 .map_err(|e| format!("HTTP error: {e}"))?;
199
200 let status = resp.status();
201 let body: serde_json::Value = resp
202 .json()
203 .await
204 .map_err(|e| format!("JSON parse error: {e}"))?;
205
206 if !status.is_success() {
207 let err = body["error"].as_str().unwrap_or("unknown error");
208 return Err(format!("Matrix API error ({}): {}", status, err));
209 }
210
211 let event_id = body["event_id"].as_str().unwrap_or("").to_string();
212 Ok(event_id)
213 }
214
215 async fn sync(&self, since: Option<&str>) -> Result<Vec<MatrixEvent>, String> {
216 let mut url = format!(
217 "{}/_matrix/client/v3/sync?timeout=30000",
218 self.homeserver_url
219 );
220 if let Some(since) = since {
221 url.push_str(&format!("&since={}", since));
222 }
223 let resp = self
224 .client
225 .get(&url)
226 .header("Authorization", format!("Bearer {}", self.access_token))
227 .send()
228 .await
229 .map_err(|e| format!("HTTP error: {e}"))?;
230
231 let body: serde_json::Value = resp
232 .json()
233 .await
234 .map_err(|e| format!("JSON parse error: {e}"))?;
235
236 let mut events = Vec::new();
237 if let Some(rooms) = body["rooms"]["join"].as_object() {
238 for (room_id, room_data) in rooms {
239 if let Some(timeline) = room_data["timeline"]["events"].as_array() {
240 for event in timeline {
241 if event["type"].as_str() == Some("m.room.message")
242 && let Some(event_body) = event["content"]["body"].as_str()
243 {
244 events.push(MatrixEvent {
245 event_id: event["event_id"].as_str().unwrap_or("").to_string(),
246 room_id: room_id.clone(),
247 sender: event["sender"].as_str().unwrap_or("").to_string(),
248 body: event_body.to_string(),
249 });
250 }
251 }
252 }
253 }
254 }
255
256 Ok(events)
257 }
258
259 async fn login(&self) -> Result<String, String> {
260 let url = format!("{}/_matrix/client/v3/account/whoami", self.homeserver_url);
262 let resp = self
263 .client
264 .get(&url)
265 .header("Authorization", format!("Bearer {}", self.access_token))
266 .send()
267 .await
268 .map_err(|e| format!("HTTP error: {e}"))?;
269
270 let status = resp.status();
271 let body: serde_json::Value = resp
272 .json()
273 .await
274 .map_err(|e| format!("JSON parse error: {e}"))?;
275
276 if !status.is_success() {
277 let err = body["error"].as_str().unwrap_or("unauthorized");
278 return Err(format!("Matrix auth failed ({}): {}", status, err));
279 }
280
281 let user_id = body["user_id"].as_str().unwrap_or("").to_string();
282 Ok(user_id)
283 }
284}
285
286pub fn create_matrix_channel(config: MatrixConfig) -> MatrixChannel {
288 let http = RealMatrixHttp::new(config.homeserver_url.clone(), config.access_token.clone());
289 MatrixChannel::new(config, Box::new(http))
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// HTTP client stub whose calls always succeed with canned data.
    struct MockMatrixHttp;

    #[async_trait]
    impl MatrixHttpClient for MockMatrixHttp {
        async fn login(&self) -> Result<String, String> {
            Ok("token123".to_string())
        }

        async fn send_message(&self, _room_id: &str, _text: &str) -> Result<String, String> {
            Ok("$event1".to_string())
        }

        async fn sync(&self, _since: Option<&str>) -> Result<Vec<MatrixEvent>, String> {
            let incoming = MatrixEvent {
                event_id: "$ev1".into(),
                room_id: "!room1:example.com".into(),
                sender: "@alice:example.com".into(),
                body: "hello matrix".into(),
            };
            Ok(vec![incoming])
        }
    }

    /// Config with the credentials `connect` requires; all else defaulted.
    fn credentialed_config() -> MatrixConfig {
        MatrixConfig {
            homeserver_url: "https://matrix.example.com".into(),
            access_token: "token".into(),
            ..Default::default()
        }
    }

    /// Creates a channel over the mock client and connects it.
    async fn connected(config: MatrixConfig) -> MatrixChannel {
        let mut channel = MatrixChannel::new(config, Box::new(MockMatrixHttp));
        channel.connect().await.unwrap();
        channel
    }

    #[tokio::test]
    async fn test_matrix_connect() {
        let config = MatrixConfig {
            user_id: "@bot:example.com".into(),
            room_ids: vec![],
            ..credentialed_config()
        };
        let channel = connected(config).await;
        assert!(channel.is_connected());
    }

    #[tokio::test]
    async fn test_matrix_send() {
        let channel = connected(credentialed_config()).await;

        let author = ChannelUser::new("@bot:ex.com", ChannelType::Matrix);
        let outgoing = ChannelMessage::text(ChannelType::Matrix, "!room1:ex.com", author, "hi");
        let sent_id = channel.send_message(outgoing).await.unwrap();
        assert_eq!(sent_id.0, "$event1");
    }

    #[tokio::test]
    async fn test_matrix_receive() {
        let channel = connected(credentialed_config()).await;

        let inbox = channel.receive_messages().await.unwrap();
        assert_eq!(inbox.len(), 1);
        assert_eq!(inbox[0].content.as_text(), Some("hello matrix"));
    }

    #[test]
    fn test_matrix_capabilities() {
        let channel = MatrixChannel::new(MatrixConfig::default(), Box::new(MockMatrixHttp));
        let caps = channel.capabilities();
        assert!(caps.supports_threads);
        assert!(caps.supports_reactions);
        assert!(caps.supports_files);
        assert!(caps.supports_editing);
        assert!(caps.supports_deletion);
        assert!(caps.max_message_length.is_none());
    }

    #[test]
    fn test_matrix_streaming_mode() {
        let channel = MatrixChannel::new(MatrixConfig::default(), Box::new(MockMatrixHttp));
        assert_eq!(channel.streaming_mode(), StreamingMode::LongPolling);
    }
}