Skip to main content

rustant_core/channels/
manager.rs

1//! Channel manager — registers, connects, polls, and broadcasts across channels.
2//!
3//! Optionally holds a [`PairingManager`] for device-pairing enforcement.
4
5use super::{
6    Channel, ChannelCapabilities, ChannelMessage, ChannelStatus, MessageId, StreamingMode,
7};
8use crate::error::{ChannelError, RustantError};
9use crate::pairing::PairingManager;
10use std::collections::HashMap;
11
12/// Manages a set of registered channels.
13///
14/// When a `PairingManager` is attached, it can be used to verify that
15/// incoming DM senders are paired before processing their messages.
16pub struct ChannelManager {
17    channels: HashMap<String, Box<dyn Channel>>,
18    pairing: Option<PairingManager>,
19}
20
21impl ChannelManager {
22    pub fn new() -> Self {
23        Self {
24            channels: HashMap::new(),
25            pairing: None,
26        }
27    }
28
29    /// Attach a pairing manager for device-pairing enforcement.
30    pub fn with_pairing(mut self, pairing: PairingManager) -> Self {
31        self.pairing = Some(pairing);
32        self
33    }
34
35    /// Access the pairing manager, if present.
36    pub fn pairing(&self) -> Option<&PairingManager> {
37        self.pairing.as_ref()
38    }
39
40    /// Mutable access to the pairing manager, if present.
41    pub fn pairing_mut(&mut self) -> Option<&mut PairingManager> {
42        self.pairing.as_mut()
43    }
44
45    /// Register a channel by name.
46    pub fn register(&mut self, channel: Box<dyn Channel>) {
47        let name = channel.name().to_string();
48        // S17: Warn when overwriting an existing channel registration
49        if self.channels.contains_key(&name) {
50            tracing::warn!(
51                channel = %name,
52                "Overwriting existing channel registration — previous instance will be dropped"
53            );
54        }
55        self.channels.insert(name, channel);
56    }
57
58    /// Number of registered channels.
59    pub fn channel_count(&self) -> usize {
60        self.channels.len()
61    }
62
63    /// List all registered channel names.
64    pub fn channel_names(&self) -> Vec<&str> {
65        self.channels.keys().map(|k| k.as_str()).collect()
66    }
67
68    /// Get the status of a channel by name.
69    pub fn channel_status(&self, name: &str) -> Option<ChannelStatus> {
70        self.channels.get(name).map(|c| c.status())
71    }
72
73    /// Connect all registered channels.
74    pub async fn connect_all(&mut self) -> Vec<(String, Result<(), RustantError>)> {
75        let mut results = Vec::new();
76        for (name, channel) in &mut self.channels {
77            let result = channel.connect().await;
78            results.push((name.clone(), result));
79        }
80        results
81    }
82
83    /// Disconnect all registered channels.
84    pub async fn disconnect_all(&mut self) -> Vec<(String, Result<(), RustantError>)> {
85        let mut results = Vec::new();
86        for (name, channel) in &mut self.channels {
87            let result = channel.disconnect().await;
88            results.push((name.clone(), result));
89        }
90        results
91    }
92
93    /// Broadcast a message to all connected channels.
94    pub async fn broadcast(
95        &self,
96        msg: ChannelMessage,
97    ) -> Vec<(String, Result<MessageId, RustantError>)> {
98        let mut results = Vec::new();
99        for (name, channel) in &self.channels {
100            if channel.is_connected() {
101                let result = channel.send_message(msg.clone()).await;
102                results.push((name.clone(), result));
103            }
104        }
105        results
106    }
107
108    /// Poll all connected channels for new messages.
109    pub async fn poll_all(&self) -> Vec<(String, Result<Vec<ChannelMessage>, RustantError>)> {
110        let mut results = Vec::new();
111        for (name, channel) in &self.channels {
112            if channel.is_connected() {
113                let result = channel.receive_messages().await;
114                results.push((name.clone(), result));
115            }
116        }
117        results
118    }
119
120    /// Send a message to a specific channel by name.
121    pub async fn send_to(
122        &self,
123        channel_name: &str,
124        msg: ChannelMessage,
125    ) -> Result<MessageId, RustantError> {
126        let channel = self.channels.get(channel_name).ok_or_else(|| {
127            RustantError::Channel(ChannelError::NotConnected {
128                name: channel_name.to_string(),
129            })
130        })?;
131        if !channel.is_connected() {
132            return Err(RustantError::Channel(ChannelError::NotConnected {
133                name: channel_name.to_string(),
134            }));
135        }
136        channel.send_message(msg).await
137    }
138
139    /// Get number of connected channels.
140    pub fn connected_count(&self) -> usize {
141        self.channels.values().filter(|c| c.is_connected()).count()
142    }
143
144    /// Get the capabilities of a channel by name.
145    pub fn get_capabilities(&self, channel_name: &str) -> Option<ChannelCapabilities> {
146        self.channels.get(channel_name).map(|c| c.capabilities())
147    }
148
149    /// List names of channels that support threads.
150    pub fn channels_supporting_threads(&self) -> Vec<&str> {
151        self.channels
152            .iter()
153            .filter(|(_, c)| c.capabilities().supports_threads)
154            .map(|(name, _)| name.as_str())
155            .collect()
156    }
157
158    /// Map of channel names to their streaming modes.
159    pub fn channels_by_streaming_mode(&self) -> HashMap<&str, StreamingMode> {
160        self.channels
161            .iter()
162            .map(|(name, c)| (name.as_str(), c.streaming_mode()))
163            .collect()
164    }
165}
166
167impl Default for ChannelManager {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173/// Build a `ChannelManager` from configuration, registering real channel implementations
174/// for each enabled/present channel config.
175pub fn build_channel_manager(config: &crate::config::ChannelsConfig) -> ChannelManager {
176    let mut mgr = ChannelManager::new();
177
178    if let Some(ref cfg) = config.slack {
179        mgr.register(Box::new(super::slack::create_slack_channel(cfg.clone())));
180    }
181
182    if let Some(ref cfg) = config.telegram {
183        mgr.register(Box::new(super::telegram::create_telegram_channel(
184            cfg.clone(),
185        )));
186    }
187
188    if let Some(ref cfg) = config.discord {
189        mgr.register(Box::new(super::discord::create_discord_channel(
190            cfg.clone(),
191        )));
192    }
193
194    if let Some(ref cfg) = config.webhook {
195        mgr.register(Box::new(super::webhook::create_webhook_channel(
196            cfg.clone(),
197        )));
198    }
199
200    if let Some(ref cfg) = config.whatsapp {
201        mgr.register(Box::new(super::whatsapp::create_whatsapp_channel(
202            cfg.clone(),
203        )));
204    }
205
206    if let Some(ref cfg) = config.sms {
207        mgr.register(Box::new(super::sms::create_sms_channel(cfg.clone())));
208    }
209
210    if let Some(ref cfg) = config.matrix {
211        mgr.register(Box::new(super::matrix::create_matrix_channel(cfg.clone())));
212    }
213
214    if let Some(ref cfg) = config.teams {
215        mgr.register(Box::new(super::teams::create_teams_channel(cfg.clone())));
216    }
217
218    if let Some(ref cfg) = config.email {
219        mgr.register(Box::new(super::email::create_email_channel(cfg.clone())));
220    }
221
222    if let Some(ref cfg) = config.irc {
223        mgr.register(Box::new(super::irc::create_irc_channel(cfg.clone())));
224    }
225
226    if let Some(ref cfg) = config.signal {
227        mgr.register(Box::new(super::signal::create_signal_channel(cfg.clone())));
228    }
229
230    #[cfg(target_os = "macos")]
231    if let Some(ref cfg) = config.imessage {
232        mgr.register(Box::new(super::imessage::create_imessage_channel(
233            cfg.clone(),
234        )));
235    }
236
237    mgr
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::types::{ChannelCapabilities, ChannelType, ChannelUser, StreamingMode};

    /// A mock channel for testing.
    ///
    /// `sent` is shared through an `Arc` so a test can keep a handle to it after
    /// the mock has been boxed and moved into the manager.
    struct MockChannel {
        name: String,
        channel_type: ChannelType,
        status: ChannelStatus,
        sent: std::sync::Arc<std::sync::Mutex<Vec<ChannelMessage>>>,
        inbox: Vec<ChannelMessage>,
        caps: ChannelCapabilities,
        mode: StreamingMode,
    }

    impl MockChannel {
        fn new(name: &str, channel_type: ChannelType) -> Self {
            Self {
                name: name.to_string(),
                channel_type,
                status: ChannelStatus::Disconnected,
                sent: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
                inbox: Vec::new(),
                caps: ChannelCapabilities::default(),
                mode: StreamingMode::default(),
            }
        }

        fn with_inbox(mut self, messages: Vec<ChannelMessage>) -> Self {
            self.inbox = messages;
            self
        }

        fn with_capabilities(mut self, caps: ChannelCapabilities) -> Self {
            self.caps = caps;
            self
        }

        fn with_streaming_mode(mut self, mode: StreamingMode) -> Self {
            self.mode = mode;
            self
        }
    }

    #[async_trait::async_trait]
    impl Channel for MockChannel {
        fn name(&self) -> &str {
            &self.name
        }

        fn channel_type(&self) -> ChannelType {
            self.channel_type
        }

        async fn connect(&mut self) -> Result<(), RustantError> {
            self.status = ChannelStatus::Connected;
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<(), RustantError> {
            self.status = ChannelStatus::Disconnected;
            Ok(())
        }

        async fn send_message(&self, msg: ChannelMessage) -> Result<MessageId, RustantError> {
            let id = msg.id.clone();
            self.sent.lock().unwrap().push(msg);
            Ok(id)
        }

        async fn receive_messages(&self) -> Result<Vec<ChannelMessage>, RustantError> {
            Ok(self.inbox.clone())
        }

        fn status(&self) -> ChannelStatus {
            self.status
        }

        fn capabilities(&self) -> ChannelCapabilities {
            self.caps.clone()
        }

        fn streaming_mode(&self) -> StreamingMode {
            self.mode.clone()
        }
    }

    #[test]
    fn test_manager_new() {
        let mgr = ChannelManager::new();
        assert_eq!(mgr.channel_count(), 0);
        assert_eq!(mgr.connected_count(), 0);
    }

    #[test]
    fn test_manager_default_has_no_pairing() {
        let mut mgr = ChannelManager::default();
        assert_eq!(mgr.channel_count(), 0);
        assert!(mgr.pairing().is_none());
        assert!(mgr.pairing_mut().is_none());
    }

    #[test]
    fn test_manager_register() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new(
            "telegram",
            ChannelType::Telegram,
        )));
        mgr.register(Box::new(MockChannel::new("slack", ChannelType::Slack)));
        assert_eq!(mgr.channel_count(), 2);
        assert!(mgr.channel_names().contains(&"telegram"));
    }

    #[test]
    fn test_manager_register_replaces_same_name() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new("tg", ChannelType::Telegram)));
        let threaded = ChannelCapabilities {
            supports_threads: true,
            ..Default::default()
        };
        mgr.register(Box::new(
            MockChannel::new("tg", ChannelType::Telegram).with_capabilities(threaded.clone()),
        ));
        // The second registration replaces the first rather than adding a duplicate.
        assert_eq!(mgr.channel_count(), 1);
        assert_eq!(mgr.get_capabilities("tg"), Some(threaded));
    }

    #[tokio::test]
    async fn test_manager_channel_status() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new("tg", ChannelType::Telegram)));
        assert!(matches!(
            mgr.channel_status("tg"),
            Some(ChannelStatus::Disconnected)
        ));

        mgr.connect_all().await;
        assert!(matches!(
            mgr.channel_status("tg"),
            Some(ChannelStatus::Connected)
        ));
        assert!(mgr.channel_status("missing").is_none());
    }

    #[tokio::test]
    async fn test_manager_connect_all() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new("tg", ChannelType::Telegram)));
        mgr.register(Box::new(MockChannel::new("sl", ChannelType::Slack)));

        let results = mgr.connect_all().await;
        assert_eq!(results.len(), 2);
        for (_, result) in &results {
            assert!(result.is_ok());
        }
        assert_eq!(mgr.connected_count(), 2);
    }

    #[tokio::test]
    async fn test_manager_disconnect_all() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new("tg", ChannelType::Telegram)));
        mgr.connect_all().await;
        assert_eq!(mgr.connected_count(), 1);

        mgr.disconnect_all().await;
        assert_eq!(mgr.connected_count(), 0);
    }

    #[tokio::test]
    async fn test_manager_broadcast() {
        let tg = MockChannel::new("tg", ChannelType::Telegram);
        let sl = MockChannel::new("sl", ChannelType::Slack);
        let tg_sent = std::sync::Arc::clone(&tg.sent);
        let sl_sent = std::sync::Arc::clone(&sl.sent);

        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(tg));
        mgr.register(Box::new(sl));
        mgr.connect_all().await;

        let sender = ChannelUser::new("bot", ChannelType::Telegram);
        let msg = ChannelMessage::text(ChannelType::Telegram, "broadcast", sender, "hello all");
        let results = mgr.broadcast(msg).await;
        assert_eq!(results.len(), 2);
        for (_, result) in &results {
            assert!(result.is_ok());
        }
        // Every connected channel actually received exactly one copy.
        assert_eq!(tg_sent.lock().unwrap().len(), 1);
        assert_eq!(sl_sent.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_manager_broadcast_skips_disconnected() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(MockChannel::new("tg", ChannelType::Telegram)));
        // Don't connect — should be skipped in broadcast

        let sender = ChannelUser::new("bot", ChannelType::Telegram);
        let msg = ChannelMessage::text(ChannelType::Telegram, "broadcast", sender, "hello");
        let results = mgr.broadcast(msg).await;
        assert_eq!(results.len(), 0); // skipped because disconnected
    }

    #[tokio::test]
    async fn test_manager_send_to() {
        let tg = MockChannel::new("tg", ChannelType::Telegram);
        let sent = std::sync::Arc::clone(&tg.sent);

        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(tg));
        mgr.connect_all().await;

        let sender = ChannelUser::new("bot", ChannelType::Telegram);
        let msg = ChannelMessage::text(ChannelType::Telegram, "chat", sender, "specific");
        let result = mgr.send_to("tg", msg).await;
        assert!(result.is_ok());

        let delivered = sent.lock().unwrap();
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].content.as_text(), Some("specific"));
    }

    #[tokio::test]
    async fn test_manager_send_to_not_found() {
        let mgr = ChannelManager::new();
        let sender = ChannelUser::new("bot", ChannelType::Telegram);
        let msg = ChannelMessage::text(ChannelType::Telegram, "chat", sender, "test");
        let result = mgr.send_to("nonexistent", msg).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_manager_send_to_disconnected() {
        let tg = MockChannel::new("tg", ChannelType::Telegram);
        let sent = std::sync::Arc::clone(&tg.sent);

        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(tg));
        // Registered but never connected — send_to must refuse and deliver nothing.

        let sender = ChannelUser::new("bot", ChannelType::Telegram);
        let msg = ChannelMessage::text(ChannelType::Telegram, "chat", sender, "test");
        let result = mgr.send_to("tg", msg).await;
        assert!(result.is_err());
        assert!(sent.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_manager_poll_all() {
        let sender = ChannelUser::new("user1", ChannelType::Telegram);
        let inbox_msg = ChannelMessage::text(ChannelType::Telegram, "chat1", sender, "incoming");

        let mut mock = MockChannel::new("tg", ChannelType::Telegram);
        mock.status = ChannelStatus::Connected;
        let mock = mock.with_inbox(vec![inbox_msg]);

        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(mock));

        let results = mgr.poll_all().await;
        assert_eq!(results.len(), 1);
        let (name, msgs) = &results[0];
        assert_eq!(name, "tg");
        let msgs = msgs.as_ref().unwrap();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content.as_text(), Some("incoming"));
    }

    #[tokio::test]
    async fn test_manager_poll_all_skips_disconnected() {
        let sender = ChannelUser::new("user1", ChannelType::Telegram);
        let inbox_msg = ChannelMessage::text(ChannelType::Telegram, "chat1", sender, "incoming");

        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(
            MockChannel::new("tg", ChannelType::Telegram).with_inbox(vec![inbox_msg]),
        ));

        // The mock has pending messages but is disconnected, so it must not be polled.
        assert!(mgr.poll_all().await.is_empty());
    }

    #[test]
    fn test_manager_get_capabilities() {
        let mut mgr = ChannelManager::new();
        let caps = ChannelCapabilities {
            supports_threads: true,
            supports_files: true,
            ..Default::default()
        };
        mgr.register(Box::new(
            MockChannel::new("tg", ChannelType::Telegram).with_capabilities(caps.clone()),
        ));
        assert_eq!(mgr.get_capabilities("tg"), Some(caps));
    }

    #[test]
    fn test_manager_channels_supporting_threads() {
        let mut mgr = ChannelManager::new();
        let threaded_caps = ChannelCapabilities {
            supports_threads: true,
            ..Default::default()
        };
        mgr.register(Box::new(
            MockChannel::new("tg", ChannelType::Telegram).with_capabilities(threaded_caps),
        ));
        mgr.register(Box::new(MockChannel::new("wc", ChannelType::WebChat)));

        let threaded = mgr.channels_supporting_threads();
        assert_eq!(threaded.len(), 1);
        assert!(threaded.contains(&"tg"));
    }

    #[test]
    fn test_manager_channels_by_streaming_mode() {
        let mut mgr = ChannelManager::new();
        mgr.register(Box::new(
            MockChannel::new("tg", ChannelType::Telegram)
                .with_streaming_mode(StreamingMode::Polling { interval_ms: 1000 }),
        ));
        mgr.register(Box::new(
            MockChannel::new("dc", ChannelType::Discord)
                .with_streaming_mode(StreamingMode::WebSocket),
        ));

        let modes = mgr.channels_by_streaming_mode();
        assert_eq!(modes.len(), 2);
        assert_eq!(modes["tg"], StreamingMode::Polling { interval_ms: 1000 });
        assert_eq!(modes["dc"], StreamingMode::WebSocket);
    }

    #[test]
    fn test_manager_capability_unknown_channel() {
        let mgr = ChannelManager::new();
        assert!(mgr.get_capabilities("nonexistent").is_none());
    }
}